use std::sync::Arc;
use aion_core::{Event, RunId, WorkflowId};
use aion_package::ContentHash;
use aion_store::EventStore;
use chrono::{DateTime, Utc};
use crate::durability::{
Command, DurabilityError, LiveExecutor, Recorder, Replay, ReplayOutcome, ReplayTerminal,
Resolution, fail_on_violation,
};
use crate::supervision::spawn_workflow_with_policy;
use crate::{EngineError, LoadedWorkflows, Pid, RuntimeHandle, RuntimeInput};
/// Instructions a [`RecoveryDriver`] supplies for replaying one active workflow.
#[derive(Clone, Debug)]
pub struct RecoveryPlan {
    /// Run id handed to `Replay::with_handoff` for this replay.
    pub run_id: RunId,
    /// Commands fed, in order, to the replay's `drive` call.
    pub commands: Vec<Command>,
    /// Timestamp passed to `fail_on_violation` if replay reports a
    /// non-determinism violation.
    pub failure_recorded_at: DateTime<Utc>,
}
/// Fields borrowed from a workflow's `WorkflowStarted` event.
struct StartedMetadata<'a> {
    // Input payload recorded on the start event.
    input: &'a aion_core::Payload,
    // Run id recorded on the start event.
    run_id: &'a RunId,
}
/// Pulls the input payload and run id out of the newest `WorkflowStarted`
/// event in `history`, after checking it was started as `expected_workflow_type`.
///
/// # Errors
///
/// Returns [`EngineError::Load`] when `history` contains no `WorkflowStarted`
/// event, or when the newest one records a different workflow type.
fn started_metadata<'a>(
    workflow_id: &WorkflowId,
    expected_workflow_type: &str,
    history: &'a [Event],
) -> Result<StartedMetadata<'a>, EngineError> {
    // Scan from the end so that, if several start events exist, the most
    // recent one is used.
    let latest_start = history.iter().rev().find_map(|event| {
        if let Event::WorkflowStarted {
            workflow_type,
            input,
            run_id,
            ..
        } = event
        {
            Some((workflow_type, input, run_id))
        } else {
            None
        }
    });

    let (workflow_type, input, run_id) = latest_start.ok_or_else(|| EngineError::Load {
        reason: format!(
            "active workflow `{workflow_id}` has no WorkflowStarted event in durable history"
        ),
    })?;

    if workflow_type == expected_workflow_type {
        Ok(StartedMetadata { input, run_id })
    } else {
        Err(EngineError::Load {
            reason: format!(
                "active workflow `{workflow_id}` started as `{workflow_type}` but recovery was requested for `{expected_workflow_type}`"
            ),
        })
    }
}
/// Builds the [`RecoveryPlan`] used to replay an active workflow.
pub trait RecoveryDriver: Send + Sync {
    /// Produces the plan for `workflow_id` from its full durable `history`.
    ///
    /// # Errors
    ///
    /// An error here aborts recovery of this workflow only; [`recover`]
    /// reports it as [`RecoveryOutcome::Failed`] with `failure_recorded: false`.
    fn recovery_plan(
        &self,
        workflow_id: &WorkflowId,
        history: &[Event],
    ) -> Result<RecoveryPlan, DurabilityError>;
}
/// Point at which replay handed control back to live execution.
#[derive(Clone, Debug, PartialEq)]
pub struct RecoveryResumePoint {
    /// Command index reported by `ReplayOutcome::ResumeLive`
    /// (presumably an index into the plan's commands — confirm).
    pub command_index: usize,
    /// Command reported alongside `command_index`.
    pub command: Command,
    /// Seq of the last event in the history read at the start of recovery
    /// (0 when that history was empty).
    pub head: u64,
}
/// Result of recovering a single active workflow.
#[derive(Debug)]
pub enum RecoveryOutcome {
    /// Replay stopped at a command that must run live.
    Resumed {
        /// Where live execution should pick up.
        resume_point: RecoveryResumePoint,
        /// Resolutions reported by the replay.
        recorded: Vec<Resolution>,
    },
    /// Replay reached a terminal state.
    Terminal {
        /// Terminal state reported by the replay.
        terminal: ReplayTerminal,
        /// Resolutions reported by the replay.
        recorded: Vec<Resolution>,
        /// Seq of the last history event read before replay (0 if empty).
        head: u64,
    },
    /// Recovery of this workflow failed.
    Failed {
        /// Either the non-determinism violation, or the error that aborted recovery.
        error: DurabilityError,
        /// `true` only when a non-determinism violation was persisted via
        /// `fail_on_violation`; `false` when recovery aborted with an error.
        failure_recorded: bool,
    },
}
/// Per-workflow entry returned by [`recover`].
#[derive(Debug)]
pub struct RecoveryReport {
    /// Workflow this report describes.
    pub workflow_id: WorkflowId,
    /// What recovery produced for it.
    pub outcome: RecoveryOutcome,
}
pub async fn recover(
store: Arc<dyn EventStore>,
executor: &dyn LiveExecutor,
driver: &dyn RecoveryDriver,
) -> Result<Vec<RecoveryReport>, DurabilityError> {
let active = store.list_active().await?;
let mut reports = Vec::with_capacity(active.len());
for workflow_id in active {
let outcome = recover_one(Arc::clone(&store), executor, driver, &workflow_id)
.await
.unwrap_or_else(|error| RecoveryOutcome::Failed {
error,
failure_recorded: false,
});
reports.push(RecoveryReport {
workflow_id,
outcome,
});
}
Ok(reports)
}
async fn recover_one(
store: Arc<dyn EventStore>,
executor: &dyn LiveExecutor,
driver: &dyn RecoveryDriver,
workflow_id: &WorkflowId,
) -> Result<RecoveryOutcome, DurabilityError> {
let history = store.read_history(workflow_id).await?;
let head = history.last().map(Event::seq).unwrap_or_default();
let mut recorder = Recorder::resume_at(workflow_id.clone(), Arc::clone(&store), head);
let plan = driver.recovery_plan(workflow_id, &history)?;
let mut replay = Replay::with_handoff(workflow_id, &plan.run_id, history, &recorder, executor)?;
match replay.drive(plan.commands) {
Ok(ReplayOutcome::ResumeLive {
command_index,
command,
recorded,
}) => Ok(RecoveryOutcome::Resumed {
resume_point: RecoveryResumePoint {
command_index,
command,
head,
},
recorded,
}),
Ok(ReplayOutcome::Terminal { terminal, recorded }) => Ok(RecoveryOutcome::Terminal {
terminal,
recorded,
head,
}),
Ok(ReplayOutcome::AwaitingCommand { recorded }) => Err(DurabilityError::HistoryShape {
reason: format!(
"recovery command stream ended before workflow {workflow_id} reached terminal or resume point after {} recorded resolutions at head {head}",
recorded.len()
),
}),
Err(DurabilityError::NonDeterminism(violation)) => {
fail_on_violation(&mut recorder, plan.failure_recorded_at, &violation).await?;
Ok(RecoveryOutcome::Failed {
error: DurabilityError::NonDeterminism(violation),
failure_recorded: true,
})
}
Err(error) => Err(error),
}
}
/// How an active workflow was brought back after a restart.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ActiveWorkflowRecovery {
    /// The workflow was re-spawned as a process on the runtime.
    Resident {
        /// Run id from the workflow's newest `WorkflowStarted` event.
        run_id: RunId,
        /// Version of the loaded workflow definition that was spawned.
        loaded_version: ContentHash,
        /// Process id of the spawned workflow.
        pid: Pid,
    },
    /// Not constructed in this file. Presumably used for workflows handled by a
    /// schedule coordinator instead of a resident process — confirm at the
    /// construction site.
    ScheduleCoordinator {
        /// Run id of the recovered workflow.
        run_id: RunId,
    },
}
/// Pluggable hook for bringing one active workflow back into execution.
pub trait ActiveWorkflowRecoverySeam: Send + Sync {
    /// Recovers `workflow_id`, expected to be of `workflow_type`, from its
    /// durable `history` using the currently `loaded_workflows`.
    ///
    /// # Errors
    ///
    /// Implementation-defined; see the implementing type.
    fn recover_active_workflow(
        &self,
        workflow_id: &WorkflowId,
        workflow_type: &str,
        history: &[Event],
        loaded_workflows: &LoadedWorkflows,
    ) -> Result<ActiveWorkflowRecovery, EngineError>;
}
/// Default [`ActiveWorkflowRecoverySeam`]: re-spawns workflows on a runtime.
pub struct ActiveWorkflowRecoverySeamImpl {
    // Runtime that recovered workflows are spawned on.
    runtime: Arc<RuntimeHandle>,
}
impl ActiveWorkflowRecoverySeamImpl {
    /// Creates a seam that spawns recovered workflows on `runtime`.
    #[must_use]
    pub fn new(runtime: Arc<RuntimeHandle>) -> Self {
        Self { runtime }
    }
}
impl ActiveWorkflowRecoverySeam for ActiveWorkflowRecoverySeamImpl {
fn recover_active_workflow(
&self,
workflow_id: &WorkflowId,
workflow_type: &str,
history: &[Event],
loaded_workflows: &LoadedWorkflows,
) -> Result<ActiveWorkflowRecovery, EngineError> {
let started = started_metadata(workflow_id, workflow_type, history)?;
let loaded = loaded_workflows
.single_loaded(workflow_type)
.map_err(|reason| EngineError::Load { reason })?;
let runtime_input = RuntimeInput::from_payload(started.input)?;
let pid = spawn_workflow_with_policy(
&self.runtime,
loaded.deployed_entry_module(),
loaded.entry_function(),
runtime_input,
)?;
Ok(ActiveWorkflowRecovery::Resident {
run_id: started.run_id.clone(),
loaded_version: loaded.version().clone(),
pid,
})
}
}
// Tests for `recover`. The fixtures are:
// - `CountingStore`: an in-memory store that logs every `read_history` call.
// - `StaticDriver`: returns a fixed one-activity plan.
// - `NoLiveExecutor`: returns an error from every live operation.
#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex},
    };

    use aion_core::{
        Event, EventEnvelope, Payload, RunId, TimerId, WorkflowFilter, WorkflowId, WorkflowSummary,
    };
    use aion_store::{
        EventStore, ReadableEventStore, RunSummary, StoreError, TimerEntry, WritableEventStore,
        WriteToken,
    };
    use async_trait::async_trait;
    use chrono::{DateTime, TimeZone, Utc};
    use serde_json::json;
    use uuid::Uuid;

    use super::{RecoveryDriver, RecoveryPlan, recover};
    use crate::durability::{
        Command, CorrelationKey, DurabilityError, LiveActivityOutcome, LiveChildOutcome,
        LiveExecutor, RecoveryOutcome,
    };

    type TestResult<T = ()> = Result<T, Box<dyn std::error::Error>>;

    /// In-memory event store that also logs which histories were read.
    #[derive(Default)]
    struct CountingStore {
        // Workflow ids returned by `list_active`.
        active: Mutex<Vec<WorkflowId>>,
        // Per-workflow event histories.
        histories: Mutex<HashMap<WorkflowId, Vec<Event>>>,
        // Every workflow id passed to `read_history`, in call order.
        reads: Mutex<Vec<WorkflowId>>,
    }

    #[async_trait]
    impl WritableEventStore for CountingStore {
        // Appends only when `expected_seq` equals the seq of the workflow's last
        // stored event (0 if none); otherwise returns `SequenceConflict`.
        async fn append(
            &self,
            _token: WriteToken,
            workflow_id: &WorkflowId,
            events: &[Event],
            expected_seq: u64,
        ) -> Result<(), StoreError> {
            let mut histories = self
                .histories
                .lock()
                .map_err(|error| StoreError::Backend(format!("history lock poisoned: {error}")))?;
            let current = histories
                .get(workflow_id)
                .and_then(|history| history.last())
                .map(Event::seq)
                .unwrap_or_default();
            if current != expected_seq {
                return Err(StoreError::SequenceConflict {
                    expected: expected_seq,
                    found: current,
                });
            }
            histories
                .entry(workflow_id.clone())
                .or_default()
                .extend(events.iter().cloned());
            Ok(())
        }
    }

    #[async_trait]
    impl ReadableEventStore for CountingStore {
        // Logs the read in `reads`, then returns a copy of the stored history
        // (empty for an unknown workflow).
        async fn read_history(&self, workflow_id: &WorkflowId) -> Result<Vec<Event>, StoreError> {
            self.reads
                .lock()
                .map_err(|error| StoreError::Backend(format!("read lock poisoned: {error}")))?
                .push(workflow_id.clone());
            Ok(self
                .histories
                .lock()
                .map_err(|error| StoreError::Backend(format!("history lock poisoned: {error}")))?
                .get(workflow_id)
                .cloned()
                .unwrap_or_default())
        }

        // Delegates to `read_history`, so this call is also logged in `reads`.
        async fn read_history_from(
            &self,
            workflow_id: &WorkflowId,
            from_seq: u64,
        ) -> Result<Vec<Event>, StoreError> {
            let history = self.read_history(workflow_id).await?;
            Ok(history
                .into_iter()
                .filter(|event| event.seq() >= from_seq)
                .collect())
        }

        // Delegates to `read_history`, so this call is also logged in `reads`.
        async fn read_run_chain(
            &self,
            workflow_id: &WorkflowId,
        ) -> Result<Vec<RunSummary>, StoreError> {
            let history = self.read_history(workflow_id).await?;
            aion_store::run_chain::run_chain_from_history(&history)
        }

        // Every workflow id with a stored history, sorted by its string form.
        async fn list_workflow_ids(&self) -> Result<Vec<WorkflowId>, StoreError> {
            let mut workflow_ids = self
                .histories
                .lock()
                .map_err(|error| StoreError::Backend(format!("history lock poisoned: {error}")))?
                .keys()
                .cloned()
                .collect::<Vec<_>>();
            workflow_ids.sort_by_key(ToString::to_string);
            Ok(workflow_ids)
        }

        // Returns `active` as stored, preserving insertion order.
        async fn list_active(&self) -> Result<Vec<WorkflowId>, StoreError> {
            Ok(self
                .active
                .lock()
                .map_err(|error| StoreError::Backend(format!("active lock poisoned: {error}")))?
                .clone())
        }

        // Minimal stub: always returns no summaries.
        async fn query(&self, filter: &WorkflowFilter) -> Result<Vec<WorkflowSummary>, StoreError> {
            let _ = filter;
            Ok(Vec::new())
        }

        // Minimal stub: accepts the timer without storing it.
        async fn schedule_timer(
            &self,
            workflow_id: &WorkflowId,
            timer_id: &TimerId,
            fire_at: DateTime<Utc>,
        ) -> Result<(), StoreError> {
            let _ = (workflow_id, timer_id, fire_at);
            Ok(())
        }

        // Minimal stub: never reports expired timers.
        async fn expired_timers(
            &self,
            as_of: DateTime<Utc>,
        ) -> Result<Vec<TimerEntry>, StoreError> {
            let _ = as_of;
            Ok(Vec::new())
        }
    }

    /// Driver that ignores history and always plans a single `RunActivity`
    /// command, with activity type `activity-<workflow_id>`, under run id 10.
    struct StaticDriver;

    impl RecoveryDriver for StaticDriver {
        fn recovery_plan(
            &self,
            workflow_id: &WorkflowId,
            history: &[Event],
        ) -> Result<RecoveryPlan, DurabilityError> {
            let _ = history;
            let activity_type = format!("activity-{workflow_id}");
            Ok(RecoveryPlan {
                // Same run id that `started_event` writes. Presumably replay
                // relies on the two agreeing — keep them in sync.
                run_id: RunId::new(Uuid::from_u128(10)),
                commands: vec![Command::RunActivity {
                    key: CorrelationKey::Activity(0),
                    activity_type,
                    input: payload("activity-input")?,
                }],
                failure_recorded_at: timestamp(99)?,
            })
        }
    }

    /// Executor whose every live operation returns an error. A test that
    /// unexpectedly reaches live execution therefore gets an error rather than
    /// a side effect.
    struct NoLiveExecutor;

    #[async_trait]
    impl LiveExecutor for NoLiveExecutor {
        async fn run_activity(
            &self,
            activity_type: String,
            input: Payload,
        ) -> Result<LiveActivityOutcome, DurabilityError> {
            let _ = (activity_type, input);
            Err(DurabilityError::HistoryShape {
                reason: "recovery replay must not execute live activity".to_owned(),
            })
        }

        async fn start_timer(
            &self,
            timer_id: TimerId,
            fire_at: DateTime<Utc>,
        ) -> Result<(), DurabilityError> {
            let _ = (timer_id, fire_at);
            Err(DurabilityError::HistoryShape {
                reason: "recovery replay must not execute live timer".to_owned(),
            })
        }

        async fn await_signal(
            &self,
            name: String,
            index: usize,
        ) -> Result<Payload, DurabilityError> {
            let _ = (name, index);
            Err(DurabilityError::HistoryShape {
                reason: "recovery replay must not execute live signal".to_owned(),
            })
        }

        async fn spawn_child(
            &self,
            workflow_type: String,
            input: Payload,
        ) -> Result<LiveChildOutcome, DurabilityError> {
            let _ = (workflow_type, input);
            Err(DurabilityError::HistoryShape {
                reason: "recovery replay must not execute live child".to_owned(),
            })
        }
    }

    /// UTC timestamp `seconds` after the Unix epoch; errors if out of range.
    fn timestamp(seconds: i64) -> Result<DateTime<Utc>, DurabilityError> {
        Utc.timestamp_opt(seconds, 0)
            .single()
            .ok_or_else(|| DurabilityError::HistoryShape {
                reason: format!("invalid timestamp {seconds}"),
            })
    }

    /// JSON payload `{"label": <label>}`.
    fn payload(label: &str) -> Result<Payload, DurabilityError> {
        Payload::from_json(&json!({ "label": label })).map_err(|error| {
            DurabilityError::HistoryShape {
                reason: format!("invalid test payload: {error}"),
            }
        })
    }

    /// `WorkflowStarted` event at seq 1 with type `"workflow"` and run id 10
    /// (the run id `StaticDriver` plans with).
    fn started_event(workflow_id: WorkflowId) -> Result<Event, DurabilityError> {
        Ok(Event::WorkflowStarted {
            envelope: EventEnvelope {
                seq: 1,
                recorded_at: timestamp(1)?,
                workflow_id,
            },
            workflow_type: "workflow".to_owned(),
            input: payload("input")?,
            run_id: aion_core::RunId::new(uuid::Uuid::from_u128(10)),
            parent_run_id: None,
        })
    }

    #[tokio::test]
    async fn recover_lists_active_and_reads_each_history() -> TestResult {
        let first = WorkflowId::new(Uuid::from_u128(1));
        let second = WorkflowId::new(Uuid::from_u128(2));
        let store = Arc::new(CountingStore::default());
        store
            .active
            .lock()
            .map_err(|error| format!("active lock poisoned: {error}"))?
            .extend([first.clone(), second.clone()]);
        // Each history holds only a WorkflowStarted event (seq 1).
        store
            .histories
            .lock()
            .map_err(|error| format!("history lock poisoned: {error}"))?
            .extend([
                (first.clone(), vec![started_event(first.clone())?]),
                (second.clone(), vec![started_event(second.clone())?]),
            ]);
        // Keep the concrete `store` handle so `reads` can be inspected afterwards.
        let store_for_recovery: Arc<dyn EventStore> = store.clone();
        let report = recover(store_for_recovery, &NoLiveExecutor, &StaticDriver).await?;
        assert_eq!(report.len(), 2);
        // No activity result is recorded, so each workflow is expected to hand off
        // to live execution. If replay had invoked `NoLiveExecutor`, it would have
        // returned an error instead.
        assert!(
            report
                .iter()
                .all(|entry| matches!(entry.outcome, RecoveryOutcome::Resumed { .. }))
        );
        // Each history is read exactly once, in `list_active` order.
        let reads = store
            .reads
            .lock()
            .map_err(|error| format!("read lock poisoned: {error}"))?
            .clone();
        assert_eq!(reads, vec![first, second]);
        Ok(())
    }
}