use std::collections::HashMap;
use std::sync::Arc;
use aws_sdk_lambda::types::Operation;
use aws_sdk_lambda::types::OperationUpdate;
use crate::backend::DurableBackend;
use crate::error::DurableError;
use crate::replay::ReplayEngine;
use crate::types::{CompensationRecord, ExecutionMode};
/// Per-invocation execution context for a durable Lambda handler.
///
/// Wraps the replay state reconstructed from checkpointed operation history
/// together with the mutable bookkeeping (current checkpoint token, batched
/// operation updates, registered compensations) accumulated while the
/// handler body runs.
pub struct DurableContext {
    // Transport used to checkpoint progress and page in execution state.
    backend: Arc<dyn DurableBackend>,
    // Replays previously checkpointed operations instead of re-executing them.
    replay_engine: ReplayEngine,
    // ARN identifying this durable execution.
    durable_execution_arn: String,
    // Token proving the right to checkpoint; refreshed after each checkpoint
    // call (see `set_checkpoint_token` / `flush_batch`).
    checkpoint_token: String,
    // Set on child contexts created via `create_child_context`; `None` at root.
    parent_op_id: Option<String>,
    // When true, updates are buffered in `pending_updates` until `flush_batch`.
    batch_mode: bool,
    // Updates queued while in batch mode, drained by `flush_batch`.
    pending_updates: Vec<OperationUpdate>,
    // Compensation records pushed during execution; drained via
    // `take_compensations` (children start with an empty list).
    compensations: Vec<CompensationRecord>,
}
/// Maximum number of operations requested per `get_execution_state` page.
const PAGE_SIZE: i32 = 1000;
impl DurableContext {
    /// Builds a context from the first page of checkpointed operations,
    /// then follows `next_marker` to page in the remainder of the history
    /// before constructing the [`ReplayEngine`].
    ///
    /// An absent or empty marker means the history is complete.
    ///
    /// # Errors
    /// Propagates any `DurableError` returned by
    /// `DurableBackend::get_execution_state`.
    pub async fn new(
        backend: Arc<dyn DurableBackend>,
        arn: String,
        checkpoint_token: String,
        initial_operations: Vec<Operation>,
        next_marker: Option<String>,
    ) -> Result<Self, DurableError> {
        // Index operations by id; later pages overwrite earlier entries with
        // the same id, so the freshest snapshot of each operation wins.
        let mut operations: HashMap<String, Operation> = initial_operations
            .into_iter()
            .map(|op| (op.id().to_string(), op))
            .collect();
        let mut marker = next_marker;
        while let Some(ref m) = marker {
            // An empty marker is treated the same as no marker: stop paging.
            if m.is_empty() {
                break;
            }
            let response = backend
                .get_execution_state(&arn, &checkpoint_token, m, PAGE_SIZE)
                .await?;
            for op in response.operations() {
                operations.insert(op.id().to_string(), op.clone());
            }
            marker = response.next_marker().map(|s| s.to_string());
        }
        let replay_engine = ReplayEngine::new(operations, None);
        Ok(Self {
            backend,
            replay_engine,
            durable_execution_arn: arn,
            checkpoint_token,
            parent_op_id: None,
            batch_mode: false,
            pending_updates: Vec::new(),
            compensations: Vec::new(),
        })
    }

    /// Current execution mode, as reported by the replay engine.
    pub fn execution_mode(&self) -> ExecutionMode {
        self.replay_engine.execution_mode()
    }

    /// True while previously checkpointed operations are being replayed.
    pub fn is_replaying(&self) -> bool {
        self.replay_engine.is_replaying()
    }

    /// ARN of this durable execution.
    pub fn arn(&self) -> &str {
        &self.durable_execution_arn
    }

    /// Current checkpoint token.
    pub fn checkpoint_token(&self) -> &str {
        &self.checkpoint_token
    }

    /// Replaces the checkpoint token (backends hand back a fresh token on
    /// every successful checkpoint).
    pub fn set_checkpoint_token(&mut self, token: String) {
        self.checkpoint_token = token;
    }

    /// Shared handle to the backing store.
    pub fn backend(&self) -> &Arc<dyn DurableBackend> {
        &self.backend
    }

    /// Mutable access to the replay engine.
    pub fn replay_engine_mut(&mut self) -> &mut ReplayEngine {
        &mut self.replay_engine
    }

    /// Creates a child context scoped to `parent_op_id`.
    ///
    /// The child shares the backend, ARN, and checkpoint token, and replays
    /// over a clone of the parent's operation map, but starts with batch mode
    /// off and with *empty* pending updates and compensations — compensations
    /// are deliberately not inherited from the parent.
    pub fn create_child_context(&self, parent_op_id: &str) -> DurableContext {
        let operations = self.replay_engine.operations().clone();
        let replay_engine = ReplayEngine::new(operations, Some(parent_op_id.to_string()));
        DurableContext {
            backend: self.backend.clone(),
            replay_engine,
            durable_execution_arn: self.durable_execution_arn.clone(),
            checkpoint_token: self.checkpoint_token.clone(),
            parent_op_id: Some(parent_op_id.to_string()),
            batch_mode: false,
            pending_updates: Vec::new(),
            compensations: Vec::new(),
        }
    }

    /// Shared access to the replay engine.
    pub fn replay_engine(&self) -> &ReplayEngine {
        &self.replay_engine
    }

    /// Operation id this context is scoped under, if it is a child context.
    pub fn parent_op_id(&self) -> Option<&str> {
        self.parent_op_id.as_deref()
    }

    /// Switches the context into batch mode; subsequent updates accumulate
    /// in `pending_updates` until `flush_batch` is called.
    pub fn enable_batch_mode(&mut self) {
        self.batch_mode = true;
    }

    /// Whether batch mode is enabled.
    pub fn is_batch_mode(&self) -> bool {
        self.batch_mode
    }

    /// Queues an operation update for the next `flush_batch`.
    pub(crate) fn push_pending_update(&mut self, update: OperationUpdate) {
        self.pending_updates.push(update);
    }

    /// Number of updates currently queued.
    pub fn pending_update_count(&self) -> usize {
        self.pending_updates.len()
    }

    /// Number of compensations currently registered.
    pub fn compensation_count(&self) -> usize {
        self.compensations.len()
    }

    /// Registers a compensation record for later rollback handling.
    pub(crate) fn push_compensation(&mut self, record: CompensationRecord) {
        self.compensations.push(record);
    }

    /// Drains and returns all registered compensations, leaving the list
    /// empty.
    pub(crate) fn take_compensations(&mut self) -> Vec<CompensationRecord> {
        std::mem::take(&mut self.compensations)
    }

    /// Sends all pending updates to the backend in a single checkpoint call
    /// and adopts the new checkpoint token from the response. A no-op when
    /// nothing is queued.
    ///
    /// # Errors
    /// Returns the backend error on a failed checkpoint, or a
    /// `checkpoint_failed` error if the response omits a checkpoint token.
    /// NOTE(review): on failure the drained updates are not re-queued — they
    /// are consumed by the call; confirm callers treat a flush error as fatal.
    pub async fn flush_batch(&mut self) -> Result<(), DurableError> {
        if self.pending_updates.is_empty() {
            return Ok(());
        }
        let updates = std::mem::take(&mut self.pending_updates);
        // Fix: the DurableBackend trait exposes `checkpoint` (see the trait
        // impls elsewhere in this crate) — `batch_checkpoint` does not exist
        // on the trait; the argument list and response shape are identical.
        let response = self
            .backend()
            .checkpoint(self.arn(), self.checkpoint_token(), updates, None)
            .await?;
        let new_token = response.checkpoint_token().ok_or_else(|| {
            DurableError::checkpoint_failed(
                "batch",
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "batch checkpoint response missing checkpoint_token",
                ),
            )
        })?;
        self.set_checkpoint_token(new_token.to_string());
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use aws_sdk_lambda::operation::checkpoint_durable_execution::CheckpointDurableExecutionOutput;
    use aws_sdk_lambda::operation::get_durable_execution_state::GetDurableExecutionStateOutput;
    use aws_sdk_lambda::types::{OperationStatus, OperationType, OperationUpdate};

    /// In-memory backend stub: serves pre-baked pages of operations, keyed by
    /// interpreting the pagination marker as a numeric page index.
    struct TestBackend {
        // Each entry is (operations for this page, marker for the next page).
        pages: Vec<(Vec<Operation>, Option<String>)>,
    }

    #[async_trait::async_trait]
    impl DurableBackend for TestBackend {
        // Checkpointing is not exercised by these context tests.
        async fn checkpoint(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            _updates: Vec<OperationUpdate>,
            _client_token: Option<&str>,
        ) -> Result<CheckpointDurableExecutionOutput, DurableError> {
            unimplemented!("not needed for context tests")
        }

        // Parses `next_marker` as a page index into `pages`; an unparsable or
        // out-of-range marker yields an empty response (no operations, no
        // next marker), which ends pagination.
        async fn get_execution_state(
            &self,
            _arn: &str,
            _checkpoint_token: &str,
            next_marker: &str,
            _max_items: i32,
        ) -> Result<GetDurableExecutionStateOutput, DurableError> {
            let page_idx: usize = next_marker.parse().unwrap_or(0);
            if page_idx >= self.pages.len() {
                return Ok(GetDurableExecutionStateOutput::builder().build().unwrap());
            }
            let (ops, marker) = &self.pages[page_idx];
            let mut builder = GetDurableExecutionStateOutput::builder();
            for op in ops {
                builder = builder.operations(op.clone());
            }
            if let Some(m) = marker {
                builder = builder.next_marker(m);
            }
            Ok(builder.build().unwrap())
        }
    }

    /// Builds a minimal Step operation with the given id and status.
    fn make_op(id: &str, status: OperationStatus) -> Operation {
        Operation::builder()
            .id(id)
            .r#type(OperationType::Step)
            .status(status)
            .start_timestamp(aws_smithy_types::DateTime::from_secs(0))
            .build()
            .unwrap()
    }

    // No history and no marker: context starts in Executing mode.
    #[tokio::test]
    async fn empty_history_creates_executing_context() {
        let backend = Arc::new(TestBackend { pages: vec![] });
        let ctx = DurableContext::new(backend, "arn:test".into(), "tok".into(), vec![], None)
            .await
            .unwrap();
        assert_eq!(ctx.execution_mode(), ExecutionMode::Executing);
        assert!(!ctx.is_replaying());
        assert_eq!(ctx.arn(), "arn:test");
        assert_eq!(ctx.checkpoint_token(), "tok");
    }

    // Operations passed directly to `new` are visible to the replay engine.
    #[tokio::test]
    async fn initial_operations_loaded() {
        let backend = Arc::new(TestBackend { pages: vec![] });
        let ops = vec![make_op("op1", OperationStatus::Succeeded)];
        let ctx = DurableContext::new(backend, "arn:test".into(), "tok".into(), ops, None)
            .await
            .unwrap();
        assert!(ctx.is_replaying());
        assert!(ctx.replay_engine().check_result("op1").is_some());
    }

    // With an initial marker of "0", the context should fetch page 0 (which
    // points to page 1) and then page 1, merging all three operations.
    #[tokio::test]
    async fn pagination_loads_all_pages() {
        let backend = Arc::new(TestBackend {
            pages: vec![
                (
                    vec![make_op("op2", OperationStatus::Succeeded)],
                    Some("1".to_string()),
                ),
                (vec![make_op("op3", OperationStatus::Succeeded)], None),
            ],
        });
        let initial = vec![make_op("op1", OperationStatus::Succeeded)];
        let ctx = DurableContext::new(
            backend,
            "arn:test".into(),
            "tok".into(),
            initial,
            Some("0".to_string()),
        )
        .await
        .unwrap();
        assert!(ctx.replay_engine().check_result("op1").is_some());
        assert!(ctx.replay_engine().check_result("op2").is_some());
        assert!(ctx.replay_engine().check_result("op3").is_some());
    }

    // `set_checkpoint_token` replaces the stored token.
    #[tokio::test]
    async fn set_checkpoint_token_updates() {
        let backend = Arc::new(TestBackend { pages: vec![] });
        let mut ctx = DurableContext::new(backend, "arn:test".into(), "tok1".into(), vec![], None)
            .await
            .unwrap();
        assert_eq!(ctx.checkpoint_token(), "tok1");
        ctx.set_checkpoint_token("tok2".to_string());
        assert_eq!(ctx.checkpoint_token(), "tok2");
    }

    // Fresh contexts start with no registered compensations.
    #[tokio::test]
    async fn new_context_has_empty_compensations() {
        let backend = Arc::new(TestBackend { pages: vec![] });
        let ctx = DurableContext::new(backend, "arn:test".into(), "tok".into(), vec![], None)
            .await
            .unwrap();
        assert_eq!(ctx.compensation_count(), 0);
    }

    // Child contexts must start with their own empty compensation list even
    // when the parent has compensations registered.
    #[tokio::test]
    async fn create_child_context_has_empty_compensations_not_inherited() {
        use crate::types::CompensationRecord;
        let backend = Arc::new(TestBackend { pages: vec![] });
        let mut ctx = DurableContext::new(backend, "arn:test".into(), "tok".into(), vec![], None)
            .await
            .unwrap();
        let record = CompensationRecord {
            name: "parent_comp".to_string(),
            forward_result_json: serde_json::Value::Null,
            compensate_fn: Box::new(|_| Box::pin(async { Ok(()) })),
        };
        ctx.push_compensation(record);
        assert_eq!(ctx.compensation_count(), 1);
        let child = ctx.create_child_context("some-op-id");
        assert_eq!(
            child.compensation_count(),
            0,
            "child context must NOT inherit parent compensations"
        );
    }
}