use serde_json::{Value, json};
use thiserror::Error;
use crate::{
channels::Channel,
event_bus::Event,
state::{StateKey, StateLifecycle, VersionedState},
};
/// The captured outcome of one run, used as one side of a replay comparison.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ReplayRun {
    /// State observed when the run finished.
    pub final_state: VersionedState,
    /// Events recorded during the run; comparisons match them up by position.
    pub events: Vec<Event>,
}

impl ReplayRun {
    /// Bundles a final state and its event log into a [`ReplayRun`].
    ///
    /// Because the type is `#[non_exhaustive]`, code outside this crate must
    /// use this constructor instead of a struct literal.
    #[must_use]
    pub fn new(final_state: VersionedState, events: Vec<Event>) -> Self {
        Self { final_state, events }
    }
}
/// Result of comparing two replay runs (or parts of them).
///
/// An empty difference list means the compared inputs matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReplayComparison {
    /// Human-readable description of each detected mismatch.
    differences: Vec<String>,
}

impl ReplayComparison {
    /// Builds a comparison result that records no differences.
    #[must_use]
    pub fn matched() -> Self {
        Self::with_differences(Vec::new())
    }

    /// Builds a comparison result from an explicit list of differences.
    ///
    /// An empty list is equivalent to [`ReplayComparison::matched`].
    #[must_use]
    pub fn with_differences(differences: Vec<String>) -> Self {
        Self { differences }
    }

    /// Returns `true` when no differences were recorded.
    #[must_use]
    pub fn is_match(&self) -> bool {
        self.differences().is_empty()
    }

    /// Borrows the recorded differences in the order they were added.
    #[must_use]
    pub fn differences(&self) -> &[String] {
        self.differences.as_slice()
    }

    /// Converts this comparison into a `Result`.
    ///
    /// # Errors
    ///
    /// Returns [`ReplayConformanceError::Mismatch`], carrying every recorded
    /// difference, when the comparison is not a match.
    pub fn assert_matches(self) -> Result<(), ReplayConformanceError> {
        if !self.is_match() {
            return Err(ReplayConformanceError::Mismatch {
                differences: self.differences,
            });
        }
        Ok(())
    }
}
/// Error produced when a replay comparison finds differences, e.g. by
/// [`ReplayComparison::assert_matches`].
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))]
#[non_exhaustive]
pub enum ReplayConformanceError {
    /// The compared runs diverged.
    ///
    /// The display message embeds `differences` using `Debug` formatting, so
    /// each entry is shown quoted inside a list.
    #[error("replay conformance mismatch: {differences:?}")]
    #[cfg_attr(
        feature = "diagnostics",
        diagnostic(code(weavegraph::replay::mismatch))
    )]
    Mismatch {
        /// One human-readable description per detected mismatch.
        differences: Vec<String>,
    },
}
#[must_use]
pub fn normalize_event(event: &Event) -> Value {
let mut value = event.to_json_value();
if let Value::Object(object) = &mut value {
object.remove("timestamp");
}
value
}
#[must_use]
pub fn normalize_state(state: &VersionedState) -> Value {
json!({
"messages": state.messages.snapshot(),
"messages_version": state.messages.version(),
"extra": state.extra.snapshot(),
"extra_version": state.extra.version(),
"errors": state.errors.snapshot(),
"errors_version": state.errors.version(),
})
}
/// Compares the normalized final states of two runs.
///
/// Returns [`ReplayComparison::matched`] when both normalize to the same JSON.
/// Otherwise returns a single difference that embeds both normalized values.
#[must_use]
pub fn compare_final_state(left: &VersionedState, right: &VersionedState) -> ReplayComparison {
    let (left_value, right_value) = (normalize_state(left), normalize_state(right));
    if left_value != right_value {
        let report = format!(
            "final state differs: left={left_value} right={right_value}"
        );
        return ReplayComparison::with_differences(vec![report]);
    }
    ReplayComparison::matched()
}
/// Compares two event sequences using the default [`normalize_event`]
/// normalizer, which ignores timestamps.
#[must_use]
pub fn compare_event_sequences(left: &[Event], right: &[Event]) -> ReplayComparison {
    let default_normalizer = normalize_event;
    compare_event_sequences_with(left, right, default_normalizer)
}
/// Compares two event sequences after passing every event through `normalizer`.
///
/// The returned comparison has no differences when both sequences have the
/// same length and every pair of normalized events is equal. Otherwise it
/// reports, in this order:
/// * an `event count differs` entry, if the lengths differ;
/// * an `event N differs` entry for the first position (within the shared
///   prefix) whose normalized values differ.
///
/// Events are normalized pairwise and scanning stops at the first divergence.
/// Only that first mismatch is ever reported, so normalizing later events
/// (or buffering every normalized value) would be wasted work.
#[must_use]
pub fn compare_event_sequences_with<F>(
    left: &[Event],
    right: &[Event],
    normalizer: F,
) -> ReplayComparison
where
    F: Fn(&Event) -> Value,
{
    let first_mismatch = left
        .iter()
        .zip(right)
        .enumerate()
        .find_map(|(index, (lhs, rhs))| {
            let (lhs, rhs) = (normalizer(lhs), normalizer(rhs));
            (lhs != rhs).then(|| format!("event {index} differs: left={} right={}", lhs, rhs))
        });

    let count_mismatch = left.len() != right.len();
    if !count_mismatch && first_mismatch.is_none() {
        return ReplayComparison::matched();
    }

    let mut differences = Vec::with_capacity(2);
    if count_mismatch {
        differences.push(format!(
            "event count differs: left={} right={}",
            left.len(),
            right.len()
        ));
    }
    differences.extend(first_mismatch);
    ReplayComparison::with_differences(differences)
}
/// Compares two runs using the default normalizers.
///
/// Final states are compared via [`compare_final_state`], and events are
/// normalized with [`normalize_event`], which drops timestamps.
#[must_use]
pub fn compare_replay_runs(left: &ReplayRun, right: &ReplayRun) -> ReplayComparison {
    let default_normalizer = normalize_event;
    compare_replay_runs_with(left, right, default_normalizer)
}
/// Compares two runs, normalizing events with `event_normalizer`.
///
/// Final-state differences (from [`compare_final_state`]) come first, followed
/// by event-sequence differences (from [`compare_event_sequences_with`]).
#[must_use]
pub fn compare_replay_runs_with<F>(
    left: &ReplayRun,
    right: &ReplayRun,
    event_normalizer: F,
) -> ReplayComparison
where
    F: Fn(&Event) -> Value,
{
    let state_report = compare_final_state(&left.final_state, &right.final_state);
    let event_report =
        compare_event_sequences_with(&left.events, &right.events, event_normalizer);
    let combined: Vec<String> = state_report
        .differences()
        .iter()
        .chain(event_report.differences())
        .cloned()
        .collect();
    ReplayComparison::with_differences(combined)
}
/// Selects `extra` state entries to leave out when normalizing state for
/// replay comparison.
///
/// Keys are kept in registration order and deduplicated by storage key. Keys
/// registered through [`StateNormalizeProfile::ignore_key`] also remember
/// their [`StateLifecycle`], so conflicting registrations can be detected.
#[derive(Debug, Default, Clone)]
pub struct StateNormalizeProfile {
    /// `(storage_key, lifecycle)` pairs. The lifecycle stays `None` until the
    /// key is registered through a typed [`StateKey`].
    ignored: Vec<(String, Option<StateLifecycle>)>,
}

impl StateNormalizeProfile {
    /// Creates an empty profile that ignores nothing.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Ignores the given raw `extra` storage keys.
    ///
    /// These registrations carry no lifecycle annotation. Registering a key
    /// that is already present is a no-op.
    #[must_use]
    pub fn ignore_extra_keys<I, S>(mut self, keys: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        for key in keys {
            self.add_raw(key.into(), None);
        }
        self
    }

    /// Ignores the `extra` entry behind a typed [`StateKey`].
    ///
    /// # Panics
    ///
    /// Panics if the same storage key was already registered through a
    /// `StateKey` with a different lifecycle.
    #[must_use]
    pub fn ignore_key<T>(mut self, key: StateKey<T>) -> Self {
        self.add_raw(key.storage_key(), Some(key.lifecycle()));
        self
    }

    /// Records `storage_key`, deduplicating against earlier registrations and
    /// checking that lifecycle annotations agree.
    fn add_raw(&mut self, storage_key: String, lifecycle: Option<StateLifecycle>) {
        let position = self.ignored.iter().position(|(k, _)| *k == storage_key);
        let Some(index) = position else {
            self.ignored.push((storage_key, lifecycle));
            return;
        };

        let existing = &mut self.ignored[index].1;
        if let (Some(a), Some(b)) = (existing.as_ref(), lifecycle.as_ref()) {
            assert!(
                a == b,
                "StateNormalizeProfile: conflicting lifecycle annotations for key {:?}: \
                 already registered as {:?}, attempted to re-register as {:?}. \
                 Ensure the same StateKey constant is used throughout.",
                storage_key,
                a,
                b
            );
        }
        // A key first registered by raw name has no lifecycle yet. Adopt the
        // first typed one so later conflicting registrations are still caught.
        if existing.is_none() {
            *existing = lifecycle;
        }
    }

    /// Iterates over the ignored storage keys in registration order.
    pub fn ignored_keys(&self) -> impl Iterator<Item = &str> {
        self.ignored.iter().map(|(k, _)| k.as_str())
    }
}
/// Converts `state` into comparable JSON, first removing every `extra` entry
/// listed in `profile`.
///
/// The value holds the snapshot and version of the `messages`, `extra`
/// (filtered) and `errors` channels. The `extra` version is reported
/// unchanged even when entries were filtered out.
#[must_use]
pub fn normalize_state_with(state: &VersionedState, profile: &StateNormalizeProfile) -> Value {
    let mut filtered_extra = state.extra.snapshot();
    profile.ignored_keys().for_each(|ignored| {
        filtered_extra.remove(ignored);
    });
    json!({
        "messages": state.messages.snapshot(),
        "messages_version": state.messages.version(),
        "extra": filtered_extra,
        "extra_version": state.extra.version(),
        "errors": state.errors.snapshot(),
        "errors_version": state.errors.version(),
    })
}
/// Compares two final states after normalizing both with `profile`.
///
/// Returns [`ReplayComparison::matched`] when the normalized values are equal.
/// Otherwise returns a single difference that embeds both normalized values.
#[must_use]
pub fn compare_final_state_with(
    left: &VersionedState,
    right: &VersionedState,
    profile: &StateNormalizeProfile,
) -> ReplayComparison {
    let (left_value, right_value) = (
        normalize_state_with(left, profile),
        normalize_state_with(right, profile),
    );
    if left_value != right_value {
        let report = format!(
            "final state differs: left={left_value} right={right_value}"
        );
        return ReplayComparison::with_differences(vec![report]);
    }
    ReplayComparison::matched()
}
/// Compares two runs, filtering state with `state_profile` and normalizing
/// events with `event_normalizer`.
///
/// Final-state differences come first, followed by event-sequence differences.
#[must_use]
pub fn compare_replay_runs_with_profile<F>(
    left: &ReplayRun,
    right: &ReplayRun,
    state_profile: &StateNormalizeProfile,
    event_normalizer: F,
) -> ReplayComparison
where
    F: Fn(&Event) -> Value,
{
    let state_report =
        compare_final_state_with(&left.final_state, &right.final_state, state_profile);
    let event_report =
        compare_event_sequences_with(&left.events, &right.events, event_normalizer);
    let combined: Vec<String> = state_report
        .differences()
        .iter()
        .chain(event_report.differences())
        .cloned()
        .collect();
    ReplayComparison::with_differences(combined)
}