use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use blake2::{Blake2b512, Digest};
use crate::error::DurableResult;
use crate::handlers::StepContext;
use crate::sealed::Sealed;
use crate::state::ExecutionState;
use crate::traits::DurableValue;
use crate::types::OperationId;
/// Identity of a single durable operation: its id, the id of the enclosing
/// operation (if nested) and an optional human-readable name.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct OperationIdentifier {
    /// Unique id of this operation.
    pub operation_id: String,
    /// Id of the enclosing operation, `None` for top-level operations.
    pub parent_id: Option<String>,
    /// Optional display name supplied by the caller.
    pub name: Option<String>,
}
impl OperationIdentifier {
    /// Builds an identifier from all three parts.
    pub fn new(
        operation_id: impl Into<String>,
        parent_id: Option<String>,
        name: Option<String>,
    ) -> Self {
        Self {
            operation_id: operation_id.into(),
            parent_id,
            name,
        }
    }

    /// Builds an unnamed identifier with no parent.
    pub fn root(operation_id: impl Into<String>) -> Self {
        Self::new(operation_id, None, None)
    }

    /// Builds an unnamed identifier nested under `parent_id`.
    pub fn with_parent(operation_id: impl Into<String>, parent_id: impl Into<String>) -> Self {
        Self::new(operation_id, Some(parent_id.into()), None)
    }

    /// Returns the identifier with its name set (replacing any previous name).
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Returns a copy of the operation id as the typed [`OperationId`].
    #[inline]
    pub fn operation_id_typed(&self) -> OperationId {
        OperationId::from(self.operation_id.clone())
    }

    /// Returns a copy of the parent id, if any, as the typed [`OperationId`].
    #[inline]
    pub fn parent_id_typed(&self) -> Option<OperationId> {
        self.parent_id.clone().map(OperationId::from)
    }

    /// Copies the parent id and name (when present) onto `update` and returns it.
    ///
    /// The operation id itself is not written here.
    pub fn apply_to(
        &self,
        update: crate::operation::OperationUpdate,
    ) -> crate::operation::OperationUpdate {
        let update = match &self.parent_id {
            Some(parent_id) => update.with_parent_id(parent_id),
            None => update,
        };
        match &self.name {
            Some(name) => update.with_name(name),
            None => update,
        }
    }
}
/// Renders as `name(operation_id)` when a name is set, otherwise just `operation_id`.
impl std::fmt::Display for OperationIdentifier {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.name {
            Some(label) => write!(formatter, "{}({})", label, self.operation_id),
            None => formatter.write_str(&self.operation_id),
        }
    }
}
/// Produces deterministic operation ids by hashing a base id with a
/// monotonically increasing counter.
#[derive(Debug)]
pub struct OperationIdGenerator {
    /// Namespace mixed into every generated id.
    base_id: String,
    /// Counter value that the next call to `next_id` will consume.
    step_counter: AtomicU64,
}
impl OperationIdGenerator {
    /// Creates a generator for `base_id` whose counter starts at zero.
    pub fn new(base_id: impl Into<String>) -> Self {
        Self::with_counter(base_id, 0)
    }

    /// Creates a generator for `base_id` whose counter starts at `initial_counter`.
    pub fn with_counter(base_id: impl Into<String>, initial_counter: u64) -> Self {
        Self {
            base_id: base_id.into(),
            step_counter: AtomicU64::new(initial_counter),
        }
    }

    /// Claims the current counter value, advances it by one, and returns the id for
    /// the claimed value.
    ///
    /// `fetch_add` is atomic, so concurrent callers always claim distinct counters;
    /// `Relaxed` suffices because no other memory is published through the counter.
    pub fn next_id(&self) -> String {
        let claimed = self.step_counter.fetch_add(1, Ordering::Relaxed);
        self.id_for_counter(claimed)
    }

    /// Returns the id for an arbitrary `counter` without touching the generator's state.
    pub fn id_for_counter(&self, counter: u64) -> String {
        generate_operation_id(&self.base_id, counter)
    }

    /// Returns the counter value the next `next_id` call will use.
    pub fn current_counter(&self) -> u64 {
        self.step_counter.load(Ordering::Relaxed)
    }

    /// Returns the base id this generator hashes into every id.
    pub fn base_id(&self) -> &str {
        &self.base_id
    }

    /// Creates a fresh generator rooted at `parent_operation_id`, starting from zero.
    ///
    /// Nothing from `self` (base id or counter) is carried over.
    pub fn create_child(&self, parent_operation_id: impl Into<String>) -> Self {
        Self::new(parent_operation_id)
    }
}
impl Clone for OperationIdGenerator {
fn clone(&self) -> Self {
Self {
base_id: self.base_id.clone(),
step_counter: AtomicU64::new(self.step_counter.load(Ordering::Relaxed)),
}
}
}
/// Derives a deterministic operation id from `base_id` and `counter`.
///
/// The id is the Blake2b-512 digest of the base id's bytes followed by the
/// counter's little-endian bytes, truncated to its first 16 bytes and rendered as
/// 32 lowercase hex characters.
pub fn generate_operation_id(base_id: &str, counter: u64) -> String {
    // Keep 128 bits of the 512-bit digest.
    const ID_BYTES: usize = 16;
    let mut digest_state = Blake2b512::new();
    digest_state.update(base_id.as_bytes());
    digest_state.update(counter.to_le_bytes());
    let digest = digest_state.finalize();
    let truncated = &digest[..ID_BYTES];
    hex::encode(truncated)
}
/// Sink for log messages emitted by durable operations.
///
/// Each method receives the message text plus the structured [`LogInfo`] context.
// `Sealed` is less visible than this trait (which is what `private_bounds` warns
// about); the bound is deliberate so that only types in this crate can implement
// `Logger`.
#[allow(private_bounds)]
pub trait Logger: Sealed + Send + Sync {
    /// Logs `message` at debug level.
    fn debug(&self, message: &str, info: &LogInfo);
    /// Logs `message` at info level.
    fn info(&self, message: &str, info: &LogInfo);
    /// Logs `message` at warning level.
    fn warn(&self, message: &str, info: &LogInfo);
    /// Logs `message` at error level.
    fn error(&self, message: &str, info: &LogInfo);
}
/// Structured context attached to a log message.
#[derive(Clone, Debug, Default)]
pub struct LogInfo {
    /// ARN of the durable execution the message belongs to.
    pub durable_execution_arn: Option<String>,
    /// Id of the operation that produced the message, if any.
    pub operation_id: Option<String>,
    /// Id of the enclosing operation, if any.
    pub parent_id: Option<String>,
    /// Whether the message was produced while replaying.
    pub is_replay: bool,
    /// Additional key/value pairs, kept in insertion order.
    pub extra: Vec<(String, String)>,
}

impl LogInfo {
    /// Creates a context for `durable_execution_arn`; every other field starts empty.
    pub fn new(durable_execution_arn: impl Into<String>) -> Self {
        Self {
            durable_execution_arn: Some(durable_execution_arn.into()),
            operation_id: None,
            parent_id: None,
            is_replay: false,
            extra: Vec::new(),
        }
    }

    /// Sets the operation id.
    pub fn with_operation_id(self, operation_id: impl Into<String>) -> Self {
        Self {
            operation_id: Some(operation_id.into()),
            ..self
        }
    }

    /// Sets the parent operation id.
    pub fn with_parent_id(self, parent_id: impl Into<String>) -> Self {
        Self {
            parent_id: Some(parent_id.into()),
            ..self
        }
    }

    /// Sets the replay flag.
    pub fn with_replay(self, is_replay: bool) -> Self {
        Self { is_replay, ..self }
    }

    /// Appends one extra key/value pair (duplicates are kept).
    pub fn with_extra(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut this = self;
        this.extra.push((key.into(), value.into()));
        this
    }
}
pub fn create_operation_span(
operation_type: &str,
op_id: &OperationIdentifier,
durable_execution_arn: &str,
) -> tracing::Span {
tracing::info_span!(
"durable_operation",
operation_type = %operation_type,
operation_id = %op_id.operation_id,
parent_id = ?op_id.parent_id,
name = ?op_id.name,
durable_execution_arn = %durable_execution_arn,
status = tracing::field::Empty,
)
}
/// [`Logger`] that forwards every message to the `tracing` macros.
#[derive(Clone, Debug, Default)]
pub struct TracingLogger;

impl Sealed for TracingLogger {}
impl TracingLogger {
    /// Renders extra fields as `key=value` pairs joined by `", "`.
    ///
    /// Returns an empty string when there are no extra fields.
    fn format_extra_fields(extra: &[(String, String)]) -> String {
        use std::fmt::Write as _;
        let mut rendered = String::new();
        for (position, (key, value)) in extra.iter().enumerate() {
            if position > 0 {
                rendered.push_str(", ");
            }
            // Writing into a String cannot fail.
            let _ = write!(rendered, "{}={}", key, value);
        }
        rendered
    }
}
impl Logger for TracingLogger {
fn debug(&self, message: &str, info: &LogInfo) {
let extra_fields = Self::format_extra_fields(&info.extra);
tracing::debug!(
durable_execution_arn = ?info.durable_execution_arn,
operation_id = ?info.operation_id,
parent_id = ?info.parent_id,
is_replay = info.is_replay,
extra = %extra_fields,
"{}",
message
);
}
fn info(&self, message: &str, info: &LogInfo) {
let extra_fields = Self::format_extra_fields(&info.extra);
tracing::info!(
durable_execution_arn = ?info.durable_execution_arn,
operation_id = ?info.operation_id,
parent_id = ?info.parent_id,
is_replay = info.is_replay,
extra = %extra_fields,
"{}",
message
);
}
fn warn(&self, message: &str, info: &LogInfo) {
let extra_fields = Self::format_extra_fields(&info.extra);
tracing::warn!(
durable_execution_arn = ?info.durable_execution_arn,
operation_id = ?info.operation_id,
parent_id = ?info.parent_id,
is_replay = info.is_replay,
extra = %extra_fields,
"{}",
message
);
}
fn error(&self, message: &str, info: &LogInfo) {
let extra_fields = Self::format_extra_fields(&info.extra);
tracing::error!(
durable_execution_arn = ?info.durable_execution_arn,
operation_id = ?info.operation_id,
parent_id = ?info.parent_id,
is_replay = info.is_replay,
extra = %extra_fields,
"{}",
message
);
}
}
/// Controls which log levels a [`ReplayAwareLogger`] lets through during replay.
///
/// Messages logged outside replay are never filtered.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ReplayLoggingConfig {
    /// Drop every message during replay (the default).
    #[default]
    SuppressAll,
    /// Let every message through during replay.
    AllowAll,
    /// Only error-level messages pass during replay.
    ErrorsOnly,
    /// Warning- and error-level messages pass during replay.
    WarningsAndErrors,
}
/// [`Logger`] decorator that filters messages flagged as replay according to a
/// [`ReplayLoggingConfig`], forwarding the rest to an inner logger.
pub struct ReplayAwareLogger {
    /// Logger that receives every message not suppressed.
    inner: Arc<dyn Logger>,
    /// Filtering policy applied to replay messages.
    config: ReplayLoggingConfig,
}

impl Sealed for ReplayAwareLogger {}
impl ReplayAwareLogger {
pub fn new(inner: Arc<dyn Logger>, config: ReplayLoggingConfig) -> Self {
Self { inner, config }
}
pub fn suppress_replay(inner: Arc<dyn Logger>) -> Self {
Self::new(inner, ReplayLoggingConfig::SuppressAll)
}
pub fn allow_all(inner: Arc<dyn Logger>) -> Self {
Self::new(inner, ReplayLoggingConfig::AllowAll)
}
pub fn config(&self) -> ReplayLoggingConfig {
self.config
}
pub fn inner(&self) -> &Arc<dyn Logger> {
&self.inner
}
fn should_suppress(&self, info: &LogInfo, level: LogLevel) -> bool {
if !info.is_replay {
return false;
}
match self.config {
ReplayLoggingConfig::SuppressAll => true,
ReplayLoggingConfig::AllowAll => false,
ReplayLoggingConfig::ErrorsOnly => level != LogLevel::Error,
ReplayLoggingConfig::WarningsAndErrors => {
level != LogLevel::Error && level != LogLevel::Warn
}
}
}
}
/// Severity of a log call, used internally to dispatch and filter messages.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum LogLevel {
    /// Verbose diagnostic output.
    Debug,
    /// Normal informational output.
    Info,
    /// Something unexpected but recoverable.
    Warn,
    /// A failure.
    Error,
}
/// Shows only the filtering policy; the inner `dyn Logger` is not required to be `Debug`.
impl std::fmt::Debug for ReplayAwareLogger {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("ReplayAwareLogger");
        out.field("config", &self.config);
        out.finish()
    }
}
impl Logger for ReplayAwareLogger {
fn debug(&self, message: &str, info: &LogInfo) {
if !self.should_suppress(info, LogLevel::Debug) {
self.inner.debug(message, info);
}
}
fn info(&self, message: &str, info: &LogInfo) {
if !self.should_suppress(info, LogLevel::Info) {
self.inner.info(message, info);
}
}
fn warn(&self, message: &str, info: &LogInfo) {
if !self.should_suppress(info, LogLevel::Warn) {
self.inner.warn(message, info);
}
}
fn error(&self, message: &str, info: &LogInfo) {
if !self.should_suppress(info, LogLevel::Error) {
self.inner.error(message, info);
}
}
}
/// [`Logger`] built from four caller-supplied callbacks, one per level.
///
/// Construct it with [`custom_logger`]. The struct itself carries no trait bounds;
/// the `Fn(&str, &LogInfo) + Send + Sync` requirements live on the impls that need
/// them (`Logger`, `Sealed`, `Debug`). This avoids repeating them wherever the type
/// is named.
pub struct CustomLogger<D, I, W, E> {
    /// Callback for debug-level messages.
    debug_fn: D,
    /// Callback for info-level messages.
    info_fn: I,
    /// Callback for warning-level messages.
    warn_fn: W,
    /// Callback for error-level messages.
    error_fn: E,
}
/// Marks callback-based loggers as eligible to implement the sealed [`Logger`] trait.
impl<
        D: Fn(&str, &LogInfo) + Send + Sync,
        I: Fn(&str, &LogInfo) + Send + Sync,
        W: Fn(&str, &LogInfo) + Send + Sync,
        E: Fn(&str, &LogInfo) + Send + Sync,
    > Sealed for CustomLogger<D, I, W, E>
{
}
/// Each level simply invokes the matching user callback with the message and context.
impl<D, I, W, E> Logger for CustomLogger<D, I, W, E>
where
    D: Fn(&str, &LogInfo) + Send + Sync,
    I: Fn(&str, &LogInfo) + Send + Sync,
    W: Fn(&str, &LogInfo) + Send + Sync,
    E: Fn(&str, &LogInfo) + Send + Sync,
{
    fn debug(&self, message: &str, info: &LogInfo) {
        let sink = &self.debug_fn;
        sink(message, info)
    }

    fn info(&self, message: &str, info: &LogInfo) {
        let sink = &self.info_fn;
        sink(message, info)
    }

    fn warn(&self, message: &str, info: &LogInfo) {
        let sink = &self.warn_fn;
        sink(message, info)
    }

    fn error(&self, message: &str, info: &LogInfo) {
        let sink = &self.error_fn;
        sink(message, info)
    }
}
/// Prints just the type name, since closures have no useful `Debug` representation.
impl<
        D: Fn(&str, &LogInfo) + Send + Sync,
        I: Fn(&str, &LogInfo) + Send + Sync,
        W: Fn(&str, &LogInfo) + Send + Sync,
        E: Fn(&str, &LogInfo) + Send + Sync,
    > std::fmt::Debug for CustomLogger<D, I, W, E>
{
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let out = &mut formatter.debug_struct("CustomLogger");
        out.finish()
    }
}
/// Builds a [`CustomLogger`] from one callback per log level.
///
/// Each callback receives the message text and its [`LogInfo`] context.
pub fn custom_logger<
    D: Fn(&str, &LogInfo) + Send + Sync,
    I: Fn(&str, &LogInfo) + Send + Sync,
    W: Fn(&str, &LogInfo) + Send + Sync,
    E: Fn(&str, &LogInfo) + Send + Sync,
>(
    debug_fn: D,
    info_fn: I,
    warn_fn: W,
    error_fn: E,
) -> CustomLogger<D, I, W, E> {
    CustomLogger {
        error_fn,
        warn_fn,
        info_fn,
        debug_fn,
    }
}
/// Builds a logger from a single callback that receives a level tag first.
///
/// The tag is one of `"DEBUG"`, `"INFO"`, `"WARN"` or `"ERROR"`, followed by the
/// message and its [`LogInfo`].
pub fn simple_custom_logger<F>(log_fn: F) -> impl Logger
where
    F: Fn(&str, &str, &LogInfo) + Send + Sync + Clone + 'static,
{
    // Three clones plus the original: one owned copy per level closure.
    let (debug_sink, info_sink, warn_sink) = (log_fn.clone(), log_fn.clone(), log_fn.clone());
    let error_sink = log_fn;
    custom_logger(
        move |message, info| debug_sink("DEBUG", message, info),
        move |message, info| info_sink("INFO", message, info),
        move |message, info| warn_sink("WARN", message, info),
        move |message, info| error_sink("ERROR", message, info),
    )
}
/// Handle through which workflow code issues durable operations (steps, waits,
/// callbacks, invokes, child contexts, and so on).
pub struct DurableContext {
    /// Execution state shared with every clone and child context.
    state: Arc<ExecutionState>,
    /// Lambda invocation context, when one was supplied.
    lambda_context: Option<lambda_runtime::Context>,
    /// Id of the operation this context is nested under; `None` at the root.
    parent_id: Option<String>,
    /// Source of operation ids. Shared by clones; child contexts get their own.
    id_generator: Arc<OperationIdGenerator>,
    /// Swappable logger slot. The outer `Arc` is shared by clones and child
    /// contexts, so replacing the logger affects all of them.
    logger: Arc<RwLock<Arc<dyn Logger>>>,
}
// The context is handed across tasks/threads, so keep it Send + Sync.
static_assertions::assert_impl_all!(DurableContext: Send, Sync);
impl DurableContext {
    /// Creates a root context for `state` with no Lambda invocation context.
    ///
    /// Operation ids are seeded from the execution ARN, and logging initially goes
    /// to a [`TracingLogger`].
    pub fn new(state: Arc<ExecutionState>) -> Self {
        Self::root_context(state, None)
    }

    /// Creates a root context for `state` that also carries the Lambda invocation
    /// context, exposed through [`DurableContext::lambda_context`].
    pub fn from_lambda_context(
        state: Arc<ExecutionState>,
        lambda_context: lambda_runtime::Context,
    ) -> Self {
        Self::root_context(state, Some(lambda_context))
    }

    /// Shared body of the root constructors.
    fn root_context(
        state: Arc<ExecutionState>,
        lambda_context: Option<lambda_runtime::Context>,
    ) -> Self {
        // Ids hash the execution ARN with a counter, so the same ARN and call order
        // reproduce the same ids.
        let base_id = state.durable_execution_arn().to_string();
        Self {
            state,
            lambda_context,
            parent_id: None,
            id_generator: Arc::new(OperationIdGenerator::new(base_id)),
            logger: Arc::new(RwLock::new(Arc::new(TracingLogger))),
        }
    }

    /// Creates a context for operations nested under `parent_operation_id`.
    ///
    /// The child shares this context's execution state, Lambda context and logger
    /// slot. It gets its own id generator seeded from the parent operation id, so
    /// its counter restarts at zero within that namespace.
    pub fn create_child_context(&self, parent_operation_id: impl Into<String>) -> Self {
        let parent_id = parent_operation_id.into();
        Self {
            state: Arc::clone(&self.state),
            lambda_context: self.lambda_context.clone(),
            id_generator: Arc::new(OperationIdGenerator::new(parent_id.clone())),
            parent_id: Some(parent_id),
            logger: Arc::clone(&self.logger),
        }
    }

    /// Replaces the logger.
    ///
    /// The logger slot is shared with every clone and child context of this
    /// context, so they all see the new logger.
    pub fn set_logger(&mut self, logger: Arc<dyn Logger>) {
        self.replace_logger(logger);
    }

    /// Builder-style variant of [`DurableContext::set_logger`]. The change is
    /// visible to clones and child contexts sharing the logger slot.
    pub fn with_logger(self, logger: Arc<dyn Logger>) -> Self {
        self.replace_logger(logger);
        self
    }

    /// Replaces the logger through a shared reference. The change is visible to
    /// clones and child contexts sharing the logger slot.
    pub fn configure_logger(&self, logger: Arc<dyn Logger>) {
        self.replace_logger(logger);
    }

    /// Swaps the logger in the shared slot.
    ///
    /// A poisoned lock is recovered rather than panicking: the slot only ever holds
    /// a complete `Arc`, so its contents cannot be left half-written.
    fn replace_logger(&self, logger: Arc<dyn Logger>) {
        let previous = {
            let mut slot = self
                .logger
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            std::mem::replace(&mut *slot, logger)
        };
        // Release the old logger only after the write lock is gone, so a destructor
        // that touches this logger slot cannot deadlock on it.
        drop(previous);
    }

    /// Returns the shared execution state.
    pub fn state(&self) -> &Arc<ExecutionState> {
        &self.state
    }

    /// Returns the ARN of the durable execution.
    pub fn durable_execution_arn(&self) -> &str {
        self.state.durable_execution_arn()
    }

    /// Returns the id of the operation this context is nested under, if any.
    pub fn parent_id(&self) -> Option<&str> {
        self.parent_id.as_deref()
    }

    /// Returns the Lambda invocation context, if one was supplied.
    pub fn lambda_context(&self) -> Option<&lambda_runtime::Context> {
        self.lambda_context.as_ref()
    }

    /// Returns a snapshot of the current logger.
    ///
    /// The lock is released before returning, so callers can run the logger (and
    /// even replace it) without holding the slot. A poisoned lock is recovered
    /// instead of panicking.
    pub fn logger(&self) -> Arc<dyn Logger> {
        let slot = self
            .logger
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        Arc::clone(&*slot)
    }

    /// Consumes and returns the next operation id from this context's generator.
    pub fn next_operation_id(&self) -> String {
        self.id_generator.next_id()
    }

    /// Consumes the next operation id and wraps it with this context's parent id
    /// and the optional `name`.
    pub fn next_operation_identifier(&self, name: Option<String>) -> OperationIdentifier {
        OperationIdentifier::new(self.next_operation_id(), self.parent_id.clone(), name)
    }

    /// Returns the counter value the next operation id will use.
    pub fn current_step_counter(&self) -> u64 {
        self.id_generator.current_counter()
    }

    /// Builds log context with the execution ARN, the parent id (if any) and the
    /// execution state's current replay flag.
    pub fn create_log_info(&self) -> LogInfo {
        self.base_log_info().with_replay(self.state.is_replay())
    }

    /// Like [`DurableContext::create_log_info`], plus `operation_id`.
    pub fn create_log_info_with_operation(&self, operation_id: &str) -> LogInfo {
        self.create_log_info().with_operation_id(operation_id)
    }

    /// Builds log context for `operation_id` using an explicit replay flag instead
    /// of the execution state's.
    pub fn create_log_info_with_replay(&self, operation_id: &str, is_replay: bool) -> LogInfo {
        self.base_log_info()
            .with_operation_id(operation_id)
            .with_replay(is_replay)
    }

    /// Log context holding only the execution ARN and, when nested, the parent id.
    fn base_log_info(&self) -> LogInfo {
        let info = LogInfo::new(self.durable_execution_arn());
        match self.parent_id.as_deref() {
            Some(parent_id) => info.with_parent_id(parent_id),
            None => info,
        }
    }

    /// Logs `message` at info level.
    pub fn log_info(&self, message: &str) {
        self.log_with_level(LogLevel::Info, message, &[]);
    }

    /// Logs `message` at info level with extra key/value `fields`.
    pub fn log_info_with(&self, message: &str, fields: &[(&str, &str)]) {
        self.log_with_level(LogLevel::Info, message, fields);
    }

    /// Logs `message` at debug level.
    pub fn log_debug(&self, message: &str) {
        self.log_with_level(LogLevel::Debug, message, &[]);
    }

    /// Logs `message` at debug level with extra key/value `fields`.
    pub fn log_debug_with(&self, message: &str, fields: &[(&str, &str)]) {
        self.log_with_level(LogLevel::Debug, message, fields);
    }

    /// Logs `message` at warning level.
    pub fn log_warn(&self, message: &str) {
        self.log_with_level(LogLevel::Warn, message, &[]);
    }

    /// Logs `message` at warning level with extra key/value `fields`.
    pub fn log_warn_with(&self, message: &str, fields: &[(&str, &str)]) {
        self.log_with_level(LogLevel::Warn, message, fields);
    }

    /// Logs `message` at error level.
    pub fn log_error(&self, message: &str) {
        self.log_with_level(LogLevel::Error, message, &[]);
    }

    /// Logs `message` at error level with extra key/value `fields`.
    pub fn log_error_with(&self, message: &str, fields: &[(&str, &str)]) {
        self.log_with_level(LogLevel::Error, message, fields);
    }

    /// Builds log context, appends `extra` in order, and dispatches to the logger.
    fn log_with_level(&self, level: LogLevel, message: &str, extra: &[(&str, &str)]) {
        let mut log_info = self.create_log_info();
        log_info.extra.extend(
            extra
                .iter()
                .map(|&(key, value)| (key.to_owned(), value.to_owned())),
        );
        // Snapshot the logger instead of holding the read guard while user code runs.
        // A custom logger that reconfigures this context would otherwise try to
        // write-lock a slot its own thread holds for reading, which deadlocks or
        // panics.
        let logger = self.logger();
        match level {
            LogLevel::Debug => logger.debug(message, &log_info),
            LogLevel::Info => logger.info(message, &log_info),
            LogLevel::Warn => logger.warn(message, &log_info),
            LogLevel::Error => logger.error(message, &log_info),
        }
    }

    /// Deserializes the execution's original input payload as `T`.
    ///
    /// # Errors
    /// - `DurableError::Validation` if no original input is available.
    /// - `DurableError::SerDes` if the payload is not valid JSON for `T`.
    pub fn get_original_input<T>(&self) -> DurableResult<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let input_payload = self.state.get_original_input_raw().ok_or_else(|| {
            crate::error::DurableError::Validation {
                message: "No original input available. The EXECUTION operation may not exist or has no input payload.".to_string(),
            }
        })?;
        serde_json::from_str(input_payload).map_err(|e| crate::error::DurableError::SerDes {
            message: format!("Failed to deserialize original input: {}", e),
        })
    }

    /// Returns the raw original input payload, if any.
    pub fn get_original_input_raw(&self) -> Option<&str> {
        self.state.get_original_input_raw()
    }

    /// Serializes `result` to JSON and records the execution as successful with it.
    ///
    /// # Errors
    /// `DurableError::SerDes` if serialization fails, otherwise whatever the
    /// execution state returns.
    pub async fn complete_execution_success<T>(&self, result: &T) -> DurableResult<()>
    where
        T: serde::Serialize,
    {
        let serialized =
            serde_json::to_string(result).map_err(|e| crate::error::DurableError::SerDes {
                message: format!("Failed to serialize execution result: {}", e),
            })?;
        self.state
            .complete_execution_success(Some(serialized))
            .await
    }

    /// Records the execution as failed with `error`.
    pub async fn complete_execution_failure(
        &self,
        error: crate::error::ErrorObject,
    ) -> DurableResult<()> {
        self.state.complete_execution_failure(error).await
    }

    /// Completes the execution through the state only when `result` would exceed
    /// the maximum invocation-output size.
    ///
    /// Returns `Ok(true)` if it completed the execution, `Ok(false)` if `result`
    /// fits and nothing was done.
    pub async fn complete_execution_if_large<T>(&self, result: &T) -> DurableResult<bool>
    where
        T: serde::Serialize,
    {
        if crate::lambda::DurableExecutionInvocationOutput::would_exceed_max_size(result) {
            self.complete_execution_success(result).await?;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Runs `func` as an unnamed step (config defaults to `StepConfig::default()`).
    pub async fn step<T, F, Fut>(
        &self,
        func: F,
        config: Option<crate::config::StepConfig>,
    ) -> DurableResult<T>
    where
        T: DurableValue,
        F: FnOnce(StepContext) -> Fut + Send,
        Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>>
            + Send,
    {
        let op_id = self.next_operation_identifier(None);
        self.run_step(op_id, func, config).await
    }

    /// Runs `func` as a step labelled `name`.
    pub async fn step_named<T, F, Fut>(
        &self,
        name: &str,
        func: F,
        config: Option<crate::config::StepConfig>,
    ) -> DurableResult<T>
    where
        T: DurableValue,
        F: FnOnce(StepContext) -> Fut + Send,
        Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>>
            + Send,
    {
        let op_id = self.next_operation_identifier(Some(name.to_string()));
        self.run_step(op_id, func, config).await
    }

    /// Shared body of `step`/`step_named`: runs the step handler for `op_id` and,
    /// on success, records the operation via `track_replay`.
    async fn run_step<T, F, Fut>(
        &self,
        op_id: OperationIdentifier,
        func: F,
        config: Option<crate::config::StepConfig>,
    ) -> DurableResult<T>
    where
        T: DurableValue,
        F: FnOnce(StepContext) -> Fut + Send,
        Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>>
            + Send,
    {
        let config = config.unwrap_or_default();
        let logger = self.logger();
        let result =
            crate::handlers::step_handler(func, &self.state, &op_id, &config, &logger).await;
        if result.is_ok() {
            self.state.track_replay(&op_id.operation_id).await;
        }
        result
    }

    /// Durably waits for `duration`, optionally labelled `name`.
    pub async fn wait(
        &self,
        duration: crate::duration::Duration,
        name: Option<&str>,
    ) -> DurableResult<()> {
        let op_id = self.next_operation_identifier(name.map(|s| s.to_string()));
        let logger = self.logger();
        let result = crate::handlers::wait_handler(duration, &self.state, &op_id, &logger).await;
        if result.is_ok() {
            self.state.track_replay(&op_id.operation_id).await;
        }
        result
    }

    /// Cancels the wait identified by `operation_id`. Consumes no new operation id.
    pub async fn cancel_wait(&self, operation_id: &str) -> DurableResult<()> {
        let logger = self.logger();
        crate::handlers::wait_cancel_handler(&self.state, operation_id, &logger).await
    }

    /// Creates an unnamed callback whose eventual payload deserializes as `T`.
    pub async fn create_callback<T>(
        &self,
        config: Option<crate::config::CallbackConfig>,
    ) -> DurableResult<crate::handlers::Callback<T>>
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let op_id = self.next_operation_identifier(None);
        self.register_callback(op_id, config).await
    }

    /// Creates a callback labelled `name`.
    pub async fn create_callback_named<T>(
        &self,
        name: &str,
        config: Option<crate::config::CallbackConfig>,
    ) -> DurableResult<crate::handlers::Callback<T>>
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let op_id = self.next_operation_identifier(Some(name.to_string()));
        self.register_callback(op_id, config).await
    }

    /// Shared body of the callback constructors: runs the callback handler for
    /// `op_id` and, on success, records the operation via `track_replay`.
    async fn register_callback<T>(
        &self,
        op_id: OperationIdentifier,
        config: Option<crate::config::CallbackConfig>,
    ) -> DurableResult<crate::handlers::Callback<T>>
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let config = config.unwrap_or_default();
        let logger = self.logger();
        let result = crate::handlers::callback_handler(&self.state, &op_id, &config, &logger).await;
        if result.is_ok() {
            self.state.track_replay(&op_id.operation_id).await;
        }
        result
    }

    /// Durably invokes `function_name` with `payload`; the operation is named
    /// `invoke:<function_name>`.
    pub async fn invoke<P, R>(
        &self,
        function_name: &str,
        payload: P,
        config: Option<crate::config::InvokeConfig<P, R>>,
    ) -> DurableResult<R>
    where
        P: serde::Serialize + serde::de::DeserializeOwned + Send,
        R: serde::Serialize + serde::de::DeserializeOwned + Send,
    {
        let op_id = self.next_operation_identifier(Some(format!("invoke:{}", function_name)));
        let config = config.unwrap_or_default();
        let logger = self.logger();
        let result = crate::handlers::invoke_handler(
            function_name,
            payload,
            &self.state,
            &op_id,
            &config,
            &logger,
        )
        .await;
        if result.is_ok() {
            self.state.track_replay(&op_id.operation_id).await;
        }
        result
    }

    /// Applies `func` to each item (with its index) under a single `map` operation.
    pub async fn map<T, U, F, Fut>(
        &self,
        items: Vec<T>,
        func: F,
        config: Option<crate::config::MapConfig>,
    ) -> DurableResult<crate::concurrency::BatchResult<U>>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Send + Sync + Clone + 'static,
        U: serde::Serialize + serde::de::DeserializeOwned + Send + 'static,
        F: Fn(DurableContext, T, usize) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = DurableResult<U>> + Send + 'static,
    {
        let op_id = self.next_operation_identifier(Some("map".to_string()));
        let config = config.unwrap_or_default();
        let logger = self.logger();
        let result =
            crate::handlers::map_handler(items, func, &self.state, &op_id, self, &config, &logger)
                .await;
        if result.is_ok() {
            self.state.track_replay(&op_id.operation_id).await;
        }
        result
    }

    /// Runs `branches` under a single `parallel` operation.
    pub async fn parallel<T, F, Fut>(
        &self,
        branches: Vec<F>,
        config: Option<crate::config::ParallelConfig>,
    ) -> DurableResult<crate::concurrency::BatchResult<T>>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Send + 'static,
        F: FnOnce(DurableContext) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = DurableResult<T>> + Send + 'static,
    {
        let op_id = self.next_operation_identifier(Some("parallel".to_string()));
        let config = config.unwrap_or_default();
        let logger = self.logger();
        let result = crate::handlers::parallel_handler(
            branches,
            &self.state,
            &op_id,
            self,
            &config,
            &logger,
        )
        .await;
        if result.is_ok() {
            self.state.track_replay(&op_id.operation_id).await;
        }
        result
    }

    /// Runs `func` in a child context under an operation named `child_context`.
    pub async fn run_in_child_context<T, F, Fut>(
        &self,
        func: F,
        config: Option<crate::config::ChildConfig>,
    ) -> DurableResult<T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Send,
        F: FnOnce(DurableContext) -> Fut + Send,
        Fut: std::future::Future<Output = DurableResult<T>> + Send,
    {
        let op_id = self.next_operation_identifier(Some("child_context".to_string()));
        self.run_child(op_id, func, config).await
    }

    /// Runs `func` in a child context under an operation labelled `name`.
    pub async fn run_in_child_context_named<T, F, Fut>(
        &self,
        name: &str,
        func: F,
        config: Option<crate::config::ChildConfig>,
    ) -> DurableResult<T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Send,
        F: FnOnce(DurableContext) -> Fut + Send,
        Fut: std::future::Future<Output = DurableResult<T>> + Send,
    {
        let op_id = self.next_operation_identifier(Some(name.to_string()));
        self.run_child(op_id, func, config).await
    }

    /// Shared body of the child-context runners: runs the child handler for `op_id`
    /// and, on success, records the operation via `track_replay`.
    async fn run_child<T, F, Fut>(
        &self,
        op_id: OperationIdentifier,
        func: F,
        config: Option<crate::config::ChildConfig>,
    ) -> DurableResult<T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Send,
        F: FnOnce(DurableContext) -> Fut + Send,
        Fut: std::future::Future<Output = DurableResult<T>> + Send,
    {
        let config = config.unwrap_or_default();
        let logger = self.logger();
        let result =
            crate::handlers::child_handler(func, &self.state, &op_id, self, &config, &logger).await;
        if result.is_ok() {
            self.state.track_replay(&op_id.operation_id).await;
        }
        result
    }

    /// Repeatedly evaluates `check` according to `config`'s wait strategy until it
    /// yields a result.
    pub async fn wait_for_condition<T, S, F, Fut>(
        &self,
        check: F,
        config: WaitForConditionConfig<S>,
    ) -> DurableResult<T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Send,
        S: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync,
        F: Fn(&S, &WaitForConditionContext) -> Fut + Send + Sync,
        Fut: std::future::Future<Output = Result<T, Box<dyn std::error::Error + Send + Sync>>>
            + Send,
    {
        let op_id = self.next_operation_identifier(Some("wait_for_condition".to_string()));
        let logger = self.logger();
        let result = crate::handlers::wait_for_condition_handler(
            check,
            config,
            &self.state,
            &op_id,
            &logger,
        )
        .await;
        if result.is_ok() {
            self.state.track_replay(&op_id.operation_id).await;
        }
        result
    }

    /// Creates a callback, hands its id to `submitter` inside a child context named
    /// `wait_for_callback_submitter`, then awaits the callback's result.
    ///
    /// Submitter failures are reported as `DurableError::UserCode` with error type
    /// `SubmitterError`.
    pub async fn wait_for_callback<T, F, Fut>(
        &self,
        submitter: F,
        config: Option<crate::config::CallbackConfig>,
    ) -> DurableResult<T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Send + Sync,
        F: FnOnce(String) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>>
            + Send
            + 'static,
    {
        let callback: crate::handlers::Callback<T> = self.create_callback(config).await?;
        let callback_id = callback.callback_id.clone();
        let op_id = self.next_operation_identifier(Some("wait_for_callback_submitter".to_string()));
        let child_config = crate::config::ChildConfig::default();
        let logger = self.logger();
        crate::handlers::child_handler(
            |child_ctx| {
                let callback_id = callback_id.clone();
                async move {
                    // NOTE(review): this step body is a no-op; `submitter` runs after it,
                    // outside the step, so it is not covered by that step's result.
                    // Confirm that re-running the submitter on re-execution is intended.
                    child_ctx
                        .step_named(
                            "execute_submitter",
                            move |_| async move {
                                Ok(())
                            },
                            None,
                        )
                        .await?;
                    submitter(callback_id).await.map_err(|e| {
                        crate::error::DurableError::UserCode {
                            message: e.to_string(),
                            error_type: "SubmitterError".to_string(),
                            stack_trace: None,
                        }
                    })?;
                    Ok(())
                }
            },
            &self.state,
            &op_id,
            self,
            &child_config,
            &logger,
        )
        .await?;
        self.state.track_replay(&op_id.operation_id).await;
        callback.result().await
    }

    /// Awaits all `futures` under a single `all` operation.
    pub async fn all<T, Fut>(&self, futures: Vec<Fut>) -> DurableResult<Vec<T>>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Send + Clone + 'static,
        Fut: std::future::Future<Output = DurableResult<T>> + Send + 'static,
    {
        let op_id = self.next_operation_identifier(Some("all".to_string()));
        let logger = self.logger();
        let result = crate::handlers::all_handler(futures, &self.state, &op_id, &logger).await;
        if result.is_ok() {
            self.state.track_replay(&op_id.operation_id).await;
        }
        result
    }

    /// Awaits all `futures` under an `all_settled` operation, collecting every
    /// outcome into a `BatchResult`.
    pub async fn all_settled<T, Fut>(
        &self,
        futures: Vec<Fut>,
    ) -> DurableResult<crate::concurrency::BatchResult<T>>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Send + Clone + 'static,
        Fut: std::future::Future<Output = DurableResult<T>> + Send + 'static,
    {
        let op_id = self.next_operation_identifier(Some("all_settled".to_string()));
        let logger = self.logger();
        let result =
            crate::handlers::all_settled_handler(futures, &self.state, &op_id, &logger).await;
        if result.is_ok() {
            self.state.track_replay(&op_id.operation_id).await;
        }
        result
    }

    /// Races `futures` under a single `race` operation.
    pub async fn race<T, Fut>(&self, futures: Vec<Fut>) -> DurableResult<T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Send + Clone + 'static,
        Fut: std::future::Future<Output = DurableResult<T>> + Send + 'static,
    {
        let op_id = self.next_operation_identifier(Some("race".to_string()));
        let logger = self.logger();
        let result = crate::handlers::race_handler(futures, &self.state, &op_id, &logger).await;
        if result.is_ok() {
            self.state.track_replay(&op_id.operation_id).await;
        }
        result
    }

    /// Awaits `futures` under a single `any` operation.
    pub async fn any<T, Fut>(&self, futures: Vec<Fut>) -> DurableResult<T>
    where
        T: serde::Serialize + serde::de::DeserializeOwned + Send + Clone + 'static,
        Fut: std::future::Future<Output = DurableResult<T>> + Send + 'static,
    {
        let op_id = self.next_operation_identifier(Some("any".to_string()));
        let logger = self.logger();
        let result = crate::handlers::any_handler(futures, &self.state, &op_id, &logger).await;
        if result.is_ok() {
            self.state.track_replay(&op_id.operation_id).await;
        }
        result
    }
}
/// Configuration for [`DurableContext::wait_for_condition`].
// The boxed strategy closure type is spelled out inline on purpose; a private type
// alias cannot name it in this public field, so silence `type_complexity`.
#[allow(clippy::type_complexity)]
pub struct WaitForConditionConfig<S> {
    /// State handed to the first condition check.
    pub initial_state: S,
    /// Given the current state and an attempt count, decides whether to keep
    /// waiting (and for how long) or stop.
    pub wait_strategy: Box<dyn Fn(&S, usize) -> crate::config::WaitDecision + Send + Sync>,
    /// Optional overall timeout.
    pub timeout: Option<crate::duration::Duration>,
    /// Optional custom serializer; presumably used for the condition state — TODO confirm.
    pub serdes: Option<std::sync::Arc<dyn crate::config::SerDesAny>>,
}
/// Shows the initial state and timeout. The strategy closure is rendered as
/// `"<fn>"`, and `serdes` only as whether one is set.
impl<S: std::fmt::Debug> std::fmt::Debug for WaitForConditionConfig<S> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_serdes = self.serdes.is_some();
        let mut out = formatter.debug_struct("WaitForConditionConfig");
        out.field("initial_state", &self.initial_state);
        out.field("wait_strategy", &"<fn>");
        out.field("timeout", &self.timeout);
        out.field("serdes", &has_serdes);
        out.finish()
    }
}
impl<S> WaitForConditionConfig<S> {
    /// Builds a config that re-checks every `interval` (whole seconds, via
    /// `to_seconds`) and stops once `max_attempts` attempts have been made.
    ///
    /// `None` means effectively unbounded (`usize::MAX`). No timeout or custom
    /// serializer is set.
    pub fn from_interval(
        initial_state: S,
        interval: crate::duration::Duration,
        max_attempts: Option<usize>,
    ) -> Self
    where
        S: Send + Sync + 'static,
    {
        let seconds = interval.to_seconds();
        let attempt_limit = max_attempts.unwrap_or(usize::MAX);
        Self {
            initial_state,
            wait_strategy: Box::new(move |_state: &S, attempts_made: usize| {
                if attempts_made < attempt_limit {
                    crate::config::WaitDecision::Continue {
                        delay: crate::duration::Duration::from_seconds(seconds),
                    }
                } else {
                    crate::config::WaitDecision::Done
                }
            }),
            timeout: None,
            serdes: None,
        }
    }
}
/// Default: `S::default()` as initial state, re-checked every 5 seconds with no
/// attempt limit.
impl<S> Default for WaitForConditionConfig<S>
where
    S: Default + Send + Sync + 'static,
{
    fn default() -> Self {
        let initial_state = S::default();
        let every_five_seconds = crate::duration::Duration::from_seconds(5);
        Self::from_interval(initial_state, every_five_seconds, None)
    }
}
/// Per-attempt information handed to a `wait_for_condition` check.
#[derive(Clone, Debug)]
pub struct WaitForConditionContext {
    /// Current attempt number.
    pub attempt: usize,
    /// Attempt limit, if one was configured.
    pub max_attempts: Option<usize>,
}
impl Clone for DurableContext {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
lambda_context: self.lambda_context.clone(),
parent_id: self.parent_id.clone(),
id_generator: self.id_generator.clone(),
logger: self.logger.clone(),
}
}
}
/// Shows the execution ARN, parent id and current step counter; other fields are
/// elided (`..`).
impl std::fmt::Debug for DurableContext {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let arn = self.durable_execution_arn();
        let counter = self.current_step_counter();
        let mut out = formatter.debug_struct("DurableContext");
        out.field("durable_execution_arn", &arn);
        out.field("parent_id", &self.parent_id);
        out.field("step_counter", &counter);
        out.finish_non_exhaustive()
    }
}
/// Minimal lowercase hexadecimal encoder (avoids an external dependency).
mod hex {
    /// Digit table indexed by nibble value.
    const HEX_CHARS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `bytes` as lowercase hex, two characters per byte, high nibble first.
    pub fn encode(bytes: &[u8]) -> String {
        let mut out = String::with_capacity(2 * bytes.len());
        bytes.iter().for_each(|&byte| {
            out.push(char::from(HEX_CHARS[usize::from(byte >> 4)]));
            out.push(char::from(HEX_CHARS[usize::from(byte & 0x0f)]));
        });
        out
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_operation_identifier_new() {
let id = OperationIdentifier::new(
"op-123",
Some("parent-456".to_string()),
Some("my-step".to_string()),
);
assert_eq!(id.operation_id, "op-123");
assert_eq!(id.parent_id, Some("parent-456".to_string()));
assert_eq!(id.name, Some("my-step".to_string()));
}
#[test]
fn test_operation_identifier_root() {
let id = OperationIdentifier::root("op-123");
assert_eq!(id.operation_id, "op-123");
assert!(id.parent_id.is_none());
assert!(id.name.is_none());
}
#[test]
fn test_operation_identifier_with_parent() {
let id = OperationIdentifier::with_parent("op-123", "parent-456");
assert_eq!(id.operation_id, "op-123");
assert_eq!(id.parent_id, Some("parent-456".to_string()));
assert!(id.name.is_none());
}
#[test]
fn test_operation_identifier_with_name() {
let id = OperationIdentifier::root("op-123").with_name("my-step");
assert_eq!(id.name, Some("my-step".to_string()));
}
#[test]
fn test_operation_identifier_display() {
let id = OperationIdentifier::root("op-123");
assert_eq!(format!("{}", id), "op-123");
let id_with_name = OperationIdentifier::root("op-123").with_name("my-step");
assert_eq!(format!("{}", id_with_name), "my-step(op-123)");
}
#[test]
fn test_generate_operation_id_deterministic() {
let id1 = generate_operation_id("base-123", 0);
let id2 = generate_operation_id("base-123", 0);
assert_eq!(id1, id2);
}
#[test]
fn test_generate_operation_id_different_counters() {
let id1 = generate_operation_id("base-123", 0);
let id2 = generate_operation_id("base-123", 1);
assert_ne!(id1, id2);
}
#[test]
fn test_generate_operation_id_different_bases() {
let id1 = generate_operation_id("base-123", 0);
let id2 = generate_operation_id("base-456", 0);
assert_ne!(id1, id2);
}
#[test]
fn test_generate_operation_id_length() {
let id = generate_operation_id("base-123", 0);
assert_eq!(id.len(), 32); }
#[test]
fn test_operation_id_generator_new() {
let gen = OperationIdGenerator::new("base-123");
assert_eq!(gen.base_id(), "base-123");
assert_eq!(gen.current_counter(), 0);
}
#[test]
fn test_operation_id_generator_with_counter() {
let gen = OperationIdGenerator::with_counter("base-123", 10);
assert_eq!(gen.current_counter(), 10);
}
#[test]
fn test_operation_id_generator_next_id() {
let gen = OperationIdGenerator::new("base-123");
let id1 = gen.next_id();
assert_eq!(gen.current_counter(), 1);
let id2 = gen.next_id();
assert_eq!(gen.current_counter(), 2);
assert_ne!(id1, id2);
}
#[test]
fn test_operation_id_generator_id_for_counter() {
let gen = OperationIdGenerator::new("base-123");
let id_0 = gen.id_for_counter(0);
let id_1 = gen.id_for_counter(1);
assert_eq!(gen.current_counter(), 0);
let next = gen.next_id();
assert_eq!(next, id_0);
let next = gen.next_id();
assert_eq!(next, id_1);
}
#[test]
fn test_operation_id_generator_create_child() {
let gen = OperationIdGenerator::new("base-123");
gen.next_id();
let child = gen.create_child("child-op-id");
assert_eq!(child.base_id(), "child-op-id");
assert_eq!(child.current_counter(), 0);
}
#[test]
fn test_operation_id_generator_clone() {
    // A clone carries over both the base ID and the current counter.
    let original = OperationIdGenerator::new("base-123");
    for _ in 0..2 {
        original.next_id();
    }
    let copy = original.clone();
    assert_eq!(copy.base_id(), original.base_id());
    assert_eq!(copy.current_counter(), original.current_counter());
}
#[test]
fn test_operation_id_generator_thread_safety() {
    use std::collections::HashSet;
    use std::thread;

    const THREADS: usize = 10;
    const IDS_PER_THREAD: usize = 100;

    // Ten threads draw from one shared generator. Every ID must be distinct,
    // and the counter must account for every call.
    let generator = Arc::new(OperationIdGenerator::new("base-123"));
    // Collect the handles first so all workers are spawned before any join.
    let workers: Vec<_> = (0..THREADS)
        .map(|_| {
            let shared = Arc::clone(&generator);
            thread::spawn(move || {
                (0..IDS_PER_THREAD)
                    .map(|_| shared.next_id())
                    .collect::<Vec<_>>()
            })
        })
        .collect();
    let all_ids: Vec<_> = workers
        .into_iter()
        .flat_map(|worker| worker.join().unwrap())
        .collect();
    let distinct: HashSet<_> = all_ids.iter().collect();
    assert_eq!(distinct.len(), 1000);
    assert_eq!(generator.current_counter(), 1000);
}
#[test]
fn test_log_info_new() {
    // A freshly built LogInfo carries only the execution ARN.
    let info = LogInfo::new("arn:test");
    assert_eq!(info.durable_execution_arn.as_deref(), Some("arn:test"));
    assert_eq!(info.operation_id, None);
    assert_eq!(info.parent_id, None);
}
#[test]
fn test_log_info_with_operation_id() {
    // `with_operation_id` sets the operation ID field.
    let info = LogInfo::new("arn:test").with_operation_id("op-123");
    assert_eq!(info.operation_id.as_deref(), Some("op-123"));
}
#[test]
fn test_log_info_with_parent_id() {
    // `with_parent_id` sets the parent ID field.
    let info = LogInfo::new("arn:test").with_parent_id("parent-456");
    assert_eq!(info.parent_id.as_deref(), Some("parent-456"));
}
#[test]
fn test_log_info_with_extra() {
    // Extra fields accumulate in insertion order.
    let info = LogInfo::new("arn:test")
        .with_extra("key1", "value1")
        .with_extra("key2", "value2");
    assert_eq!(info.extra.len(), 2);
    let (key, value) = &info.extra[0];
    assert_eq!((key.as_str(), value.as_str()), ("key1", "value1"));
}
#[test]
fn test_hex_encode() {
    // Each byte becomes two lower-case hex digits, in input order.
    let cases: [(&[u8], &str); 3] = [
        (&[0x00], "00"),
        (&[0xff], "ff"),
        (&[0x12, 0x34, 0xab, 0xcd], "1234abcd"),
    ];
    for (bytes, expected) in cases.iter() {
        assert_eq!(hex::encode(*bytes), *expected);
    }
}
}
#[cfg(test)]
mod durable_context_tests {
    //! Unit tests for `DurableContext`, together with the `LogInfo`,
    //! `ReplayLoggingConfig` and `ReplayAwareLogger` logging helpers.

    use super::*;
    use crate::client::SharedDurableServiceClient;
    use crate::lambda::InitialExecutionState;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    /// Execution ARN given to every `ExecutionState` built in this module.
    const TEST_ARN: &str = "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123";

    // The callback builders below are macros rather than functions. This way
    // each closure's parameter types are still inferred from `custom_logger`'s
    // bounds at the call site, exactly as if the closure were written inline.

    // Callback that adds one to the `Arc<AtomicUsize>` `$count` per invocation.
    macro_rules! counting {
        ($count:expr) => {{
            let count = Arc::clone(&$count);
            move |_, _| {
                count.fetch_add(1, Ordering::SeqCst);
            }
        }};
    }

    // Callback that stores a clone of the latest `LogInfo` it receives into the
    // `Arc<Mutex<Option<LogInfo>>>` `$slot`.
    macro_rules! capturing {
        ($slot:expr) => {{
            let slot = Arc::clone(&$slot);
            move |_, info: &LogInfo| {
                *slot.lock().unwrap() = Some(info.clone());
            }
        }};
    }

    // Logger that counts only `info` calls into `$count`; other levels do nothing.
    macro_rules! info_counting_logger {
        ($count:expr) => {
            custom_logger(|_, _| {}, counting!($count), |_, _| {}, |_, _| {})
        };
    }

    // Logger whose four callbacks each bump the matching field of the
    // `LevelCounts` value `$counts`.
    macro_rules! level_counting_logger {
        ($counts:ident) => {
            custom_logger(
                counting!($counts.debug),
                counting!($counts.info),
                counting!($counts.warn),
                counting!($counts.error),
            )
        };
    }

    fn create_mock_client() -> SharedDurableServiceClient {
        use crate::client::MockDurableServiceClient;
        Arc::new(MockDurableServiceClient::new())
    }

    /// Builds a fresh execution state for `TEST_ARN`, backed by a mock client.
    fn create_test_state() -> Arc<ExecutionState> {
        Arc::new(ExecutionState::new(
            TEST_ARN,
            "token-123",
            InitialExecutionState::new(),
            create_mock_client(),
        ))
    }

    /// A call counter starting at zero, shareable with logger callbacks.
    fn new_counter() -> Arc<AtomicUsize> {
        Arc::new(AtomicUsize::new(0))
    }

    /// Current value of `counter`.
    fn read(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    /// An empty slot that a capturing callback fills with the latest `LogInfo`.
    fn new_slot() -> Arc<Mutex<Option<LogInfo>>> {
        Arc::new(Mutex::new(None))
    }

    /// Independent call counters for each level of a counting logger.
    struct LevelCounts {
        debug: Arc<AtomicUsize>,
        info: Arc<AtomicUsize>,
        warn: Arc<AtomicUsize>,
        error: Arc<AtomicUsize>,
    }

    impl LevelCounts {
        fn new() -> Self {
            Self {
                debug: new_counter(),
                info: new_counter(),
                warn: new_counter(),
                error: new_counter(),
            }
        }

        /// Current counts in `[debug, info, warn, error]` order.
        fn snapshot(&self) -> [usize; 4] {
            [
                read(&self.debug),
                read(&self.info),
                read(&self.warn),
                read(&self.error),
            ]
        }
    }

    /// Sends the message `"test"` with `info` through every level of `logger`.
    fn emit_all_levels(logger: &dyn Logger, info: &LogInfo) {
        logger.debug("test", info);
        logger.info("test", info);
        logger.warn("test", info);
        logger.error("test", info);
    }

    /// Asserts that `slot` holds a `LogInfo` whose `extra` has exactly as many
    /// entries as `expected` and contains every expected pair.
    fn assert_captured_extra(slot: &Mutex<Option<LogInfo>>, expected: &[(&str, &str)]) {
        let guard = slot.lock().unwrap();
        let info = guard.as_ref().unwrap();
        assert_eq!(info.extra.len(), expected.len());
        for &(key, value) in expected {
            assert!(info.extra.contains(&(key.to_string(), value.to_string())));
        }
    }

    #[test]
    fn test_durable_context_new() {
        let state = create_test_state();
        let ctx = DurableContext::new(Arc::clone(&state));
        assert_eq!(ctx.durable_execution_arn(), TEST_ARN);
        assert!(ctx.parent_id().is_none());
        assert!(ctx.lambda_context().is_none());
        assert_eq!(ctx.current_step_counter(), 0);
    }

    #[test]
    fn test_durable_context_next_operation_id() {
        let ctx = DurableContext::new(create_test_state());
        let first = ctx.next_operation_id();
        let second = ctx.next_operation_id();
        assert_ne!(first, second);
        assert_eq!(ctx.current_step_counter(), 2);
    }

    #[test]
    fn test_durable_context_next_operation_identifier() {
        let ctx = DurableContext::new(create_test_state());
        let op_id = ctx.next_operation_identifier(Some("my-step".to_string()));
        assert!(op_id.parent_id.is_none());
        assert_eq!(op_id.name.as_deref(), Some("my-step"));
        assert!(!op_id.operation_id.is_empty());
    }

    #[test]
    fn test_durable_context_create_child_context() {
        let ctx = DurableContext::new(create_test_state());
        let parent_op_id = ctx.next_operation_id();
        let child_ctx = ctx.create_child_context(&parent_op_id);
        assert_eq!(child_ctx.parent_id(), Some(parent_op_id.as_str()));
        assert_eq!(child_ctx.current_step_counter(), 0);
        assert_eq!(child_ctx.durable_execution_arn(), ctx.durable_execution_arn());
    }

    #[test]
    fn test_durable_context_child_generates_different_ids() {
        let ctx = DurableContext::new(create_test_state());
        let parent_op_id = ctx.next_operation_id();
        let child_ctx = ctx.create_child_context(&parent_op_id);
        let child_id = child_ctx.next_operation_id();
        let sibling_id = ctx.next_operation_id();
        assert_ne!(child_id, sibling_id);
    }

    #[test]
    fn test_durable_context_child_operation_identifier_has_parent() {
        let ctx = DurableContext::new(create_test_state());
        let parent_op_id = ctx.next_operation_id();
        let child_ctx = ctx.create_child_context(&parent_op_id);
        let child_op_id = child_ctx.next_operation_identifier(None);
        assert_eq!(child_op_id.parent_id, Some(parent_op_id));
    }

    #[test]
    fn test_durable_context_set_logger() {
        let mut ctx = DurableContext::new(create_test_state());
        let replacement: Arc<dyn Logger> = Arc::new(TracingLogger);
        ctx.set_logger(replacement);
        // Smoke check only: the accessor must still be callable after the swap.
        let _ = ctx.logger();
    }

    #[test]
    fn test_durable_context_with_logger() {
        let replacement: Arc<dyn Logger> = Arc::new(TracingLogger);
        let ctx_with_logger = DurableContext::new(create_test_state()).with_logger(replacement);
        // Smoke check only: the accessor must be callable on the rebuilt context.
        let _ = ctx_with_logger.logger();
    }

    #[test]
    fn test_durable_context_create_log_info() {
        let ctx = DurableContext::new(create_test_state());
        let info = ctx.create_log_info();
        assert_eq!(info.durable_execution_arn.as_deref(), Some(TEST_ARN));
        assert!(info.parent_id.is_none());
    }

    #[test]
    fn test_durable_context_create_log_info_with_parent() {
        let ctx = DurableContext::new(create_test_state());
        let parent_op_id = ctx.next_operation_id();
        let info = ctx.create_child_context(&parent_op_id).create_log_info();
        assert_eq!(info.parent_id, Some(parent_op_id));
    }

    #[test]
    fn test_durable_context_create_log_info_with_operation() {
        let ctx = DurableContext::new(create_test_state());
        let info = ctx.create_log_info_with_operation("op-123");
        assert_eq!(info.operation_id.as_deref(), Some("op-123"));
    }

    #[test]
    fn test_log_info_with_replay() {
        let info = LogInfo::new("arn:test")
            .with_operation_id("op-123")
            .with_replay(true);
        assert!(info.is_replay);
        assert_eq!(info.operation_id.as_deref(), Some("op-123"));
    }

    #[test]
    fn test_log_info_default_not_replay() {
        assert!(!LogInfo::default().is_replay);
    }

    #[test]
    fn test_replay_logging_config_default() {
        assert_eq!(ReplayLoggingConfig::default(), ReplayLoggingConfig::SuppressAll);
    }

    #[test]
    fn test_replay_aware_logger_suppress_all() {
        let counts = LevelCounts::new();
        let inner = Arc::new(level_counting_logger!(counts));
        let logger = ReplayAwareLogger::new(inner, ReplayLoggingConfig::SuppressAll);

        // Outside replay, everything is forwarded.
        let non_replay_info = LogInfo::new("arn:test").with_replay(false);
        emit_all_levels(&logger, &non_replay_info);
        assert_eq!(counts.snapshot(), [1, 1, 1, 1]);

        // During replay, nothing is forwarded.
        let replay_info = LogInfo::new("arn:test").with_replay(true);
        emit_all_levels(&logger, &replay_info);
        assert_eq!(counts.snapshot(), [1, 1, 1, 1]);
    }

    #[test]
    fn test_replay_aware_logger_allow_all() {
        let calls = new_counter();
        let inner = Arc::new(custom_logger(
            counting!(calls),
            counting!(calls),
            counting!(calls),
            counting!(calls),
        ));
        let logger = ReplayAwareLogger::allow_all(inner);
        let replay_info = LogInfo::new("arn:test").with_replay(true);
        emit_all_levels(&logger, &replay_info);
        assert_eq!(read(&calls), 4);
    }

    #[test]
    fn test_replay_aware_logger_errors_only() {
        let counts = LevelCounts::new();
        let inner = Arc::new(level_counting_logger!(counts));
        let logger = ReplayAwareLogger::new(inner, ReplayLoggingConfig::ErrorsOnly);
        let replay_info = LogInfo::new("arn:test").with_replay(true);
        emit_all_levels(&logger, &replay_info);
        assert_eq!(counts.snapshot(), [0, 0, 0, 1]);
    }

    #[test]
    fn test_replay_aware_logger_warnings_and_errors() {
        let counts = LevelCounts::new();
        let inner = Arc::new(level_counting_logger!(counts));
        let logger = ReplayAwareLogger::new(inner, ReplayLoggingConfig::WarningsAndErrors);
        let replay_info = LogInfo::new("arn:test").with_replay(true);
        emit_all_levels(&logger, &replay_info);
        assert_eq!(counts.snapshot(), [0, 0, 1, 1]);
    }

    #[test]
    fn test_replay_aware_logger_suppress_replay_constructor() {
        let inner: Arc<dyn Logger> = Arc::new(TracingLogger);
        let logger = ReplayAwareLogger::suppress_replay(inner);
        assert_eq!(logger.config(), ReplayLoggingConfig::SuppressAll);
    }

    #[test]
    fn test_replay_aware_logger_debug() {
        let inner: Arc<dyn Logger> = Arc::new(TracingLogger);
        let logger = ReplayAwareLogger::new(inner, ReplayLoggingConfig::SuppressAll);
        let rendered = format!("{:?}", logger);
        assert!(rendered.contains("ReplayAwareLogger"));
        assert!(rendered.contains("SuppressAll"));
    }

    #[test]
    fn test_durable_context_create_log_info_with_replay_method() {
        let ctx = DurableContext::new(create_test_state());
        let info = ctx.create_log_info_with_replay("op-123", true);
        assert_eq!(info.operation_id.as_deref(), Some("op-123"));
        assert!(info.is_replay);
    }

    #[test]
    fn test_durable_context_clone() {
        let ctx = DurableContext::new(create_test_state());
        for _ in 0..2 {
            ctx.next_operation_id();
        }
        let cloned = ctx.clone();
        assert_eq!(cloned.durable_execution_arn(), ctx.durable_execution_arn());
        assert_eq!(cloned.current_step_counter(), ctx.current_step_counter());
    }

    #[test]
    fn test_durable_context_debug() {
        let ctx = DurableContext::new(create_test_state());
        let rendered = format!("{:?}", ctx);
        assert!(rendered.contains("DurableContext"));
        assert!(rendered.contains("durable_execution_arn"));
    }

    #[test]
    fn test_durable_context_state_access() {
        let state = create_test_state();
        let ctx = DurableContext::new(Arc::clone(&state));
        assert!(Arc::ptr_eq(ctx.state(), &state));
    }

    #[test]
    fn test_durable_context_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<DurableContext>();
    }

    #[test]
    fn test_log_info_method() {
        let info_calls = new_counter();
        let slot = new_slot();
        let inner = Arc::new(custom_logger(
            |_, _| {},
            {
                let calls = Arc::clone(&info_calls);
                let slot = Arc::clone(&slot);
                move |_, info: &LogInfo| {
                    calls.fetch_add(1, Ordering::SeqCst);
                    *slot.lock().unwrap() = Some(info.clone());
                }
            },
            |_, _| {},
            |_, _| {},
        ));
        let ctx = DurableContext::new(create_test_state()).with_logger(inner);
        ctx.log_info("Test message");
        assert_eq!(read(&info_calls), 1);
        let captured = slot.lock().unwrap();
        let info = captured.as_ref().unwrap();
        assert_eq!(info.durable_execution_arn.as_deref(), Some(TEST_ARN));
    }

    #[test]
    fn test_log_debug_method() {
        let debug_calls = new_counter();
        let inner = Arc::new(custom_logger(
            counting!(debug_calls),
            |_, _| {},
            |_, _| {},
            |_, _| {},
        ));
        let ctx = DurableContext::new(create_test_state()).with_logger(inner);
        ctx.log_debug("Debug message");
        assert_eq!(read(&debug_calls), 1);
    }

    #[test]
    fn test_log_warn_method() {
        let warn_calls = new_counter();
        let inner = Arc::new(custom_logger(
            |_, _| {},
            |_, _| {},
            counting!(warn_calls),
            |_, _| {},
        ));
        let ctx = DurableContext::new(create_test_state()).with_logger(inner);
        ctx.log_warn("Warning message");
        assert_eq!(read(&warn_calls), 1);
    }

    #[test]
    fn test_log_error_method() {
        let error_calls = new_counter();
        let inner = Arc::new(custom_logger(
            |_, _| {},
            |_, _| {},
            |_, _| {},
            counting!(error_calls),
        ));
        let ctx = DurableContext::new(create_test_state()).with_logger(inner);
        ctx.log_error("Error message");
        assert_eq!(read(&error_calls), 1);
    }

    #[test]
    fn test_log_info_with_extra_fields() {
        let slot = new_slot();
        let inner = Arc::new(custom_logger(
            |_, _| {},
            capturing!(slot),
            |_, _| {},
            |_, _| {},
        ));
        let ctx = DurableContext::new(create_test_state()).with_logger(inner);
        ctx.log_info_with("Test message", &[("order_id", "123"), ("amount", "99.99")]);
        assert_captured_extra(&slot, &[("order_id", "123"), ("amount", "99.99")]);
    }

    #[test]
    fn test_log_debug_with_extra_fields() {
        let slot = new_slot();
        let inner = Arc::new(custom_logger(
            capturing!(slot),
            |_, _| {},
            |_, _| {},
            |_, _| {},
        ));
        let ctx = DurableContext::new(create_test_state()).with_logger(inner);
        ctx.log_debug_with("Debug message", &[("key", "value")]);
        assert_captured_extra(&slot, &[("key", "value")]);
    }

    #[test]
    fn test_log_warn_with_extra_fields() {
        let slot = new_slot();
        let inner = Arc::new(custom_logger(
            |_, _| {},
            |_, _| {},
            capturing!(slot),
            |_, _| {},
        ));
        let ctx = DurableContext::new(create_test_state()).with_logger(inner);
        ctx.log_warn_with("Warning message", &[("retry", "3")]);
        assert_captured_extra(&slot, &[("retry", "3")]);
    }

    #[test]
    fn test_log_error_with_extra_fields() {
        let slot = new_slot();
        let inner = Arc::new(custom_logger(
            |_, _| {},
            |_, _| {},
            |_, _| {},
            capturing!(slot),
        ));
        let ctx = DurableContext::new(create_test_state()).with_logger(inner);
        ctx.log_error_with(
            "Error message",
            &[("error_code", "E001"), ("details", "Something went wrong")],
        );
        assert_captured_extra(
            &slot,
            &[("error_code", "E001"), ("details", "Something went wrong")],
        );
    }

    #[test]
    fn test_log_methods_include_parent_id_in_child_context() {
        let slot = new_slot();
        let inner: Arc<dyn Logger> = Arc::new(custom_logger(
            |_, _| {},
            capturing!(slot),
            |_, _| {},
            |_, _| {},
        ));
        let ctx = DurableContext::new(create_test_state()).with_logger(Arc::clone(&inner));
        let parent_op_id = ctx.next_operation_id();
        let child_ctx = ctx.create_child_context(&parent_op_id).with_logger(inner);
        child_ctx.log_info("Child context message");
        let captured = slot.lock().unwrap();
        let info = captured.as_ref().unwrap();
        assert_eq!(info.parent_id, Some(parent_op_id));
    }

    #[test]
    fn test_configure_logger_swaps_logger() {
        let original_calls = new_counter();
        let new_calls = new_counter();
        let original_logger: Arc<dyn Logger> = Arc::new(info_counting_logger!(original_calls));
        let new_logger: Arc<dyn Logger> = Arc::new(info_counting_logger!(new_calls));
        let ctx = DurableContext::new(create_test_state()).with_logger(original_logger);

        ctx.log_info("before swap");
        assert_eq!(read(&original_calls), 1);
        assert_eq!(read(&new_calls), 0);

        ctx.configure_logger(new_logger);
        ctx.log_info("after swap");
        assert_eq!(read(&original_calls), 1);
        assert_eq!(read(&new_calls), 1);
    }

    #[test]
    fn test_original_logger_used_when_configure_logger_not_called() {
        let original_calls = new_counter();
        let original_logger: Arc<dyn Logger> = Arc::new(info_counting_logger!(original_calls));
        let ctx = DurableContext::new(create_test_state()).with_logger(original_logger);
        for message in ["message 1", "message 2", "message 3"] {
            ctx.log_info(message);
        }
        assert_eq!(read(&original_calls), 3);
    }

    #[test]
    fn test_configure_logger_affects_child_contexts() {
        let original_calls = new_counter();
        let new_calls = new_counter();
        let original_logger: Arc<dyn Logger> = Arc::new(info_counting_logger!(original_calls));
        let new_logger: Arc<dyn Logger> = Arc::new(info_counting_logger!(new_calls));
        let ctx = DurableContext::new(create_test_state()).with_logger(original_logger);
        let parent_op_id = ctx.next_operation_id();
        let child_ctx = ctx.create_child_context(&parent_op_id);

        child_ctx.log_info("child before swap");
        assert_eq!(read(&original_calls), 1);

        // Reconfiguring the parent must be visible through the existing child.
        ctx.configure_logger(new_logger);
        child_ctx.log_info("child after swap");
        assert_eq!(read(&new_calls), 1);
        assert_eq!(read(&original_calls), 1);
    }
}
#[cfg(test)]
mod property_tests {
    //! Property-based tests (via `proptest`) for operation-ID generation and
    //! for the context that `DurableContext` logging attaches to each record.

    use super::*;
    use proptest::prelude::*;
    use std::collections::HashSet;
    use std::sync::Arc;

    /// Execution ARN given to every `ExecutionState` built by the logging properties.
    const TEST_ARN: &str = "arn:aws:lambda:us-east-1:123456789012:function:test:durable:abc123";

    // Logger callback that stores a clone of the latest `LogInfo` it receives
    // into the `Arc<Mutex<Option<LogInfo>>>` `$slot`. This is a macro rather
    // than a fn so the closure's parameter types are inferred from
    // `custom_logger`'s bounds at each call site.
    macro_rules! capture_into {
        ($slot:expr) => {{
            let slot = ::std::sync::Arc::clone(&$slot);
            move |_, info: &LogInfo| {
                *slot.lock().unwrap() = Some(info.clone());
            }
        }};
    }

    // Logger whose four level callbacks all capture into `$slot`.
    macro_rules! capturing_logger {
        ($slot:expr) => {
            custom_logger(
                capture_into!($slot),
                capture_into!($slot),
                capture_into!($slot),
                capture_into!($slot),
            )
        };
    }

    /// Number of distinct IDs in `ids`.
    fn distinct_count(ids: &[String]) -> usize {
        ids.iter().collect::<HashSet<_>>().len()
    }

    /// Replays `operations` against a fresh generator rooted at `base_id` and
    /// returns every ID produced, in order.
    ///
    /// Each operation draws one ID. An operation of type `2` additionally opens
    /// a child generator under that ID and draws one ID from the child.
    fn replay_ids(base_id: &str, operations: &[u32]) -> Vec<String> {
        let generator = OperationIdGenerator::new(base_id);
        let mut ids = Vec::new();
        for &op_type in operations {
            let id = generator.next_id();
            ids.push(id.clone());
            if op_type == 2 {
                ids.push(generator.create_child(id).next_id());
            }
        }
        ids
    }

    /// Fresh `ExecutionState` for `TEST_ARN`, backed by a mock service client.
    fn create_test_state() -> Arc<ExecutionState> {
        use crate::client::MockDurableServiceClient;
        use crate::lambda::InitialExecutionState;
        let client: crate::client::SharedDurableServiceClient =
            Arc::new(MockDurableServiceClient::new());
        Arc::new(ExecutionState::new(
            TEST_ARN,
            "token-123",
            InitialExecutionState::new(),
            client,
        ))
    }

    /// Logs `message` on `ctx` at the level selected by `level`
    /// (0 = debug, 1 = info, 2 = warn, anything else = error).
    fn log_at_level(ctx: &DurableContext, level: u8, message: &str) {
        match level {
            0 => ctx.log_debug(message),
            1 => ctx.log_info(message),
            2 => ctx.log_warn(message),
            _ => ctx.log_error(message),
        }
    }

    /// Like `log_at_level`, but uses the `*_with` variants to attach `fields`.
    fn log_with_at_level(ctx: &DurableContext, level: u8, message: &str, fields: &[(&str, &str)]) {
        match level {
            0 => ctx.log_debug_with(message, fields),
            1 => ctx.log_info_with(message, fields),
            2 => ctx.log_warn_with(message, fields),
            _ => ctx.log_error_with(message, fields),
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_operation_id_determinism(
            base_id in "[a-zA-Z0-9:/-]{1,100}",
            counter in 0u64..10000u64,
        ) {
            let first = generate_operation_id(&base_id, counter);
            let second = generate_operation_id(&base_id, counter);
            prop_assert_eq!(&first, &second, "Same base_id and counter must produce identical IDs");
            prop_assert_eq!(first.len(), 32, "ID must be 32 hex characters");
            prop_assert!(first.chars().all(|c| c.is_ascii_hexdigit()), "ID must be valid hex");
        }

        #[test]
        fn prop_operation_id_generator_determinism(
            base_id in "[a-zA-Z0-9:/-]{1,100}",
            num_ids in 1usize..50usize,
        ) {
            let left = OperationIdGenerator::new(&base_id);
            let right = OperationIdGenerator::new(&base_id);
            let left_ids: Vec<String> = (0..num_ids).map(|_| left.next_id()).collect();
            let right_ids: Vec<String> = (0..num_ids).map(|_| right.next_id()).collect();
            prop_assert_eq!(&left_ids, &right_ids, "Same generator sequence must produce identical IDs");
            prop_assert_eq!(distinct_count(&left_ids), num_ids, "All IDs in sequence must be unique");
        }

        #[test]
        fn prop_operation_id_replay_determinism(
            base_id in "[a-zA-Z0-9:/-]{1,100}",
            operations in prop::collection::vec(0u32..3u32, 1..20),
        ) {
            let first_run = replay_ids(&base_id, &operations);
            let second_run = replay_ids(&base_id, &operations);
            prop_assert_eq!(&first_run, &second_run, "Replay must produce identical operation IDs");
        }
    }

    mod concurrent_id_tests {
        use super::*;
        use std::sync::Arc;
        use std::thread;

        /// Spawns `threads` workers that each draw `per_thread` IDs from the
        /// shared `generator`, then joins them and returns every ID, worker by worker.
        fn draw_concurrently(
            generator: &Arc<OperationIdGenerator>,
            threads: usize,
            per_thread: usize,
        ) -> Vec<String> {
            // Collect the handles first so every worker is running before any join.
            let workers: Vec<_> = (0..threads)
                .map(|_| {
                    let shared = Arc::clone(generator);
                    thread::spawn(move || {
                        (0..per_thread)
                            .map(|_| shared.next_id())
                            .collect::<Vec<_>>()
                    })
                })
                .collect();
            workers
                .into_iter()
                .flat_map(|worker| worker.join().unwrap())
                .collect()
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn prop_concurrent_id_uniqueness(
                base_id in "[a-zA-Z0-9:/-]{1,100}",
                num_threads in 2usize..10usize,
                ids_per_thread in 10usize..100usize,
            ) {
                let generator = Arc::new(OperationIdGenerator::new(&base_id));
                let all_ids = draw_concurrently(&generator, num_threads, ids_per_thread);
                let total_expected = num_threads * ids_per_thread;
                prop_assert_eq!(all_ids.len(), total_expected, "Should have generated {} IDs", total_expected);
                let unique_count = distinct_count(&all_ids);
                prop_assert_eq!(
                    unique_count,
                    total_expected,
                    "All {} IDs must be unique, but only {} were unique",
                    total_expected,
                    unique_count
                );
                prop_assert_eq!(
                    generator.current_counter() as usize,
                    total_expected,
                    "Counter should equal total IDs generated"
                );
            }

            #[test]
            fn prop_concurrent_id_uniqueness_stress(
                base_id in "[a-zA-Z0-9:/-]{1,50}",
            ) {
                let num_threads: usize = 20;
                let ids_per_thread: usize = 500;
                let generator = Arc::new(OperationIdGenerator::new(&base_id));
                let all_ids = draw_concurrently(&generator, num_threads, ids_per_thread);
                let total_expected = num_threads * ids_per_thread;
                prop_assert_eq!(
                    distinct_count(&all_ids),
                    total_expected,
                    "All {} IDs must be unique under high concurrency",
                    total_expected
                );
            }
        }
    }

    mod logging_automatic_context_tests {
        use super::*;
        use std::sync::{Arc, Mutex};

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn prop_logging_automatic_context(
                message in "[a-zA-Z0-9 ]{1,100}",
                log_level in 0u8..4u8,
            ) {
                let slot = Arc::new(Mutex::new(None::<LogInfo>));
                let inner = Arc::new(capturing_logger!(slot));
                let ctx = DurableContext::new(create_test_state()).with_logger(inner);
                log_at_level(&ctx, log_level, &message);

                let captured = slot.lock().unwrap();
                let info = captured.as_ref().expect("LogInfo should be captured");
                prop_assert!(
                    info.durable_execution_arn.is_some(),
                    "durable_execution_arn must be automatically included"
                );
                prop_assert_eq!(
                    info.durable_execution_arn.as_ref().unwrap(),
                    TEST_ARN,
                    "durable_execution_arn must match the context's ARN"
                );
            }

            #[test]
            fn prop_logging_automatic_context_child(
                message in "[a-zA-Z0-9 ]{1,100}",
                log_level in 0u8..4u8,
            ) {
                let slot = Arc::new(Mutex::new(None::<LogInfo>));
                let inner: Arc<dyn Logger> = Arc::new(capturing_logger!(slot));
                let ctx = DurableContext::new(create_test_state()).with_logger(Arc::clone(&inner));
                let parent_op_id = ctx.next_operation_id();
                let child_ctx = ctx.create_child_context(&parent_op_id).with_logger(inner);
                log_at_level(&child_ctx, log_level, &message);

                let captured = slot.lock().unwrap();
                let info = captured.as_ref().expect("LogInfo should be captured");
                prop_assert!(
                    info.durable_execution_arn.is_some(),
                    "durable_execution_arn must be automatically included in child context"
                );
                prop_assert!(
                    info.parent_id.is_some(),
                    "parent_id must be automatically included in child context"
                );
                prop_assert_eq!(
                    info.parent_id.as_ref().unwrap(),
                    &parent_op_id,
                    "parent_id must match the parent operation ID"
                );
            }
        }
    }

    mod logging_extra_fields_tests {
        use super::*;
        use std::sync::{Arc, Mutex};

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn prop_logging_extra_fields(
                message in "[a-zA-Z0-9 ]{1,100}",
                log_level in 0u8..4u8,
                key1 in "[a-zA-Z_][a-zA-Z0-9_]{0,20}",
                value1 in "[a-zA-Z0-9]{1,50}",
                key2 in "[a-zA-Z_][a-zA-Z0-9_]{0,20}",
                value2 in "[a-zA-Z0-9]{1,50}",
            ) {
                let slot = Arc::new(Mutex::new(None::<LogInfo>));
                let inner = Arc::new(capturing_logger!(slot));
                let ctx = DurableContext::new(create_test_state()).with_logger(inner);
                let fields = [
                    (key1.as_str(), value1.as_str()),
                    (key2.as_str(), value2.as_str()),
                ];
                log_with_at_level(&ctx, log_level, &message, &fields);

                let captured = slot.lock().unwrap();
                let info = captured.as_ref().expect("LogInfo should be captured");
                prop_assert_eq!(
                    info.extra.len(),
                    2,
                    "Extra fields must be included in the log output"
                );
                prop_assert!(
                    info.extra.contains(&(key1.clone(), value1.clone())),
                    "First extra field must be present: {}={}", key1, value1
                );
                prop_assert!(
                    info.extra.contains(&(key2.clone(), value2.clone())),
                    "Second extra field must be present: {}={}", key2, value2
                );
            }

            #[test]
            fn prop_logging_extra_fields_empty(
                message in "[a-zA-Z0-9 ]{1,100}",
                log_level in 0u8..4u8,
            ) {
                let slot = Arc::new(Mutex::new(None::<LogInfo>));
                let inner = Arc::new(capturing_logger!(slot));
                let ctx = DurableContext::new(create_test_state()).with_logger(inner);
                let empty_fields: &[(&str, &str)] = &[];
                log_with_at_level(&ctx, log_level, &message, empty_fields);

                let captured = slot.lock().unwrap();
                let info = captured.as_ref().expect("LogInfo should be captured");
                prop_assert!(
                    info.extra.is_empty(),
                    "Extra fields should be empty when none are provided"
                );
                prop_assert!(
                    info.durable_execution_arn.is_some(),
                    "durable_execution_arn must still be present even with empty extra fields"
                );
            }
        }
    }
}
#[cfg(test)]
mod sealed_trait_tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // NOTE(review): this module's only content is `logger_tests`; the
    // `sealed_trait_tests` name looks stale — consider renaming it.

    /// Unit tests for the `Logger` implementations: `TracingLogger`,
    /// `ReplayAwareLogger`, `custom_logger` and `simple_custom_logger`.
    mod logger_tests {
        use super::*;

        /// Ordered record of which level callbacks a logger invoked.
        type LevelLog = Arc<std::sync::Mutex<Vec<&'static str>>>;

        /// Builds a `custom_logger` whose four callbacks append their level name
        /// (`"debug"`, `"info"`, `"warn"`, `"error"`) to the returned `LevelLog`.
        ///
        /// Recording *which* callback fired, rather than bumping one shared counter,
        /// lets tests catch a call that is routed to the wrong level's callback.
        /// NOTE(review): assumes `custom_logger` takes its callbacks in
        /// debug, info, warn, error order — confirm against its definition.
        fn recording_logger() -> (impl Logger, LevelLog) {
            let log: LevelLog = Arc::new(std::sync::Mutex::new(Vec::new()));
            let logger = custom_logger(
                {
                    let sink = Arc::clone(&log);
                    move |_msg, _info| sink.lock().unwrap().push("debug")
                },
                {
                    let sink = Arc::clone(&log);
                    move |_msg, _info| sink.lock().unwrap().push("info")
                },
                {
                    let sink = Arc::clone(&log);
                    move |_msg, _info| sink.lock().unwrap().push("warn")
                },
                {
                    let sink = Arc::clone(&log);
                    move |_msg, _info| sink.lock().unwrap().push("error")
                },
            );
            (logger, log)
        }

        /// Smoke test: `TracingLogger` is usable as `&dyn Logger` at every level
        /// without panicking.
        #[test]
        fn test_tracing_logger_implements_logger() {
            let logger: &dyn Logger = &TracingLogger;
            let info = LogInfo::default();
            logger.debug("test debug", &info);
            logger.info("test info", &info);
            logger.warn("test warn", &info);
            logger.error("test error", &info);
        }

        /// Smoke test: `ReplayAwareLogger` wrapping `TracingLogger` is usable as
        /// `&dyn Logger` at every level without panicking.
        #[test]
        fn test_replay_aware_logger_implements_logger() {
            let inner = Arc::new(TracingLogger);
            let logger = ReplayAwareLogger::new(inner, ReplayLoggingConfig::AllowAll);
            let logger_ref: &dyn Logger = &logger;
            let info = LogInfo::default();
            logger_ref.debug("test debug", &info);
            logger_ref.info("test info", &info);
            logger_ref.warn("test warn", &info);
            logger_ref.error("test error", &info);
        }

        /// Each `Logger` method of a `custom_logger` must reach the callback
        /// registered for that same level, exactly once.
        #[test]
        fn test_custom_logger_implements_logger() {
            let (logger, calls) = recording_logger();
            let logger_ref: &dyn Logger = &logger;
            let info = LogInfo::default();
            logger_ref.debug("test", &info);
            logger_ref.info("test", &info);
            logger_ref.warn("test", &info);
            logger_ref.error("test", &info);
            assert_eq!(*calls.lock().unwrap(), ["debug", "info", "warn", "error"]);
        }

        /// `simple_custom_logger` invokes its single callback once per log call.
        #[test]
        fn test_simple_custom_logger() {
            let call_count = Arc::new(AtomicUsize::new(0));
            let count_clone = call_count.clone();
            let logger = simple_custom_logger(move |_level, _msg, _info| {
                count_clone.fetch_add(1, Ordering::SeqCst);
            });
            let info = LogInfo::default();
            logger.debug("test", &info);
            logger.info("test", &info);
            logger.warn("test", &info);
            logger.error("test", &info);
            assert_eq!(call_count.load(Ordering::SeqCst), 4);
        }

        /// `simple_custom_logger` passes the level (formatted via `Display`) and the
        /// original message through to its callback.
        #[test]
        fn test_custom_logger_receives_correct_messages() {
            let messages = Arc::new(std::sync::Mutex::new(Vec::new()));
            let messages_clone = messages.clone();
            let logger = simple_custom_logger(move |level, msg, _info| {
                messages_clone
                    .lock()
                    .unwrap()
                    .push(format!("[{}] {}", level, msg));
            });
            let info = LogInfo::default();
            logger.debug("debug message", &info);
            logger.info("info message", &info);
            logger.warn("warn message", &info);
            logger.error("error message", &info);
            let logged = messages.lock().unwrap();
            assert_eq!(logged.len(), 4);
            assert_eq!(logged[0], "[DEBUG] debug message");
            assert_eq!(logged[1], "[INFO] info message");
            assert_eq!(logged[2], "[WARN] warn message");
            assert_eq!(logged[3], "[ERROR] error message");
        }

        /// `simple_custom_logger` hands the caller's `LogInfo` fields to its callback.
        #[test]
        fn test_custom_logger_receives_log_info() {
            let received_info = Arc::new(std::sync::Mutex::new(None));
            let info_clone = received_info.clone();
            let logger = simple_custom_logger(move |_level, _msg, info| {
                *info_clone.lock().unwrap() = Some(info.clone());
            });
            let info = LogInfo::new("arn:aws:test")
                .with_operation_id("op-123")
                .with_parent_id("parent-456")
                .with_replay(true);
            logger.info("test", &info);
            let received = received_info.lock().unwrap().clone().unwrap();
            assert_eq!(
                received.durable_execution_arn,
                Some("arn:aws:test".to_string())
            );
            assert_eq!(received.operation_id, Some("op-123".to_string()));
            assert_eq!(received.parent_id, Some("parent-456".to_string()));
            assert!(received.is_replay);
        }

        /// With `SuppressAll`, non-replay calls reach the inner logger at their own
        /// level, and replay calls reach it not at all.
        #[test]
        fn test_replay_aware_logger_suppresses_during_replay() {
            let (inner, calls) = recording_logger();
            let logger = ReplayAwareLogger::new(Arc::new(inner), ReplayLoggingConfig::SuppressAll);

            let non_replay_info = LogInfo::default().with_replay(false);
            logger.debug("test", &non_replay_info);
            logger.info("test", &non_replay_info);
            logger.warn("test", &non_replay_info);
            logger.error("test", &non_replay_info);
            assert_eq!(*calls.lock().unwrap(), ["debug", "info", "warn", "error"]);

            let replay_info = LogInfo::default().with_replay(true);
            logger.debug("test", &replay_info);
            logger.info("test", &replay_info);
            logger.warn("test", &replay_info);
            logger.error("test", &replay_info);
            // Nothing new may reach the inner logger while replaying.
            assert_eq!(*calls.lock().unwrap(), ["debug", "info", "warn", "error"]);
        }

        /// With `ErrorsOnly`, only the error-level call survives during replay.
        #[test]
        fn test_replay_aware_logger_errors_only_during_replay() {
            let (inner, calls) = recording_logger();
            let logger = ReplayAwareLogger::new(Arc::new(inner), ReplayLoggingConfig::ErrorsOnly);

            let replay_info = LogInfo::default().with_replay(true);
            logger.debug("test", &replay_info);
            logger.info("test", &replay_info);
            logger.warn("test", &replay_info);
            logger.error("test", &replay_info);
            // Checking the level (not just a count of 1) ensures the survivor is the error call.
            assert_eq!(*calls.lock().unwrap(), ["error"]);
        }
    }
}