1use trellis_core::{
2 GraphError, OutputFrameTrace, ResourceCommandTrace, ResourceKey, Revision, TraceMismatch,
3 TransactionId, TransactionResult, TransactionTrace, assert_transaction_traces_match,
4};
5
6use crate::{FullRecomputeOracle, OracleCheck, OracleMismatch, assert_incremental_equals_full};
7
/// One named entry in a [`Scenario`]: a label plus the transaction trace recorded under it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ScenarioStep {
    /// Label the step was recorded under; `Scenario::step` looks steps up by this name.
    pub name: String,
    /// Trace of the transaction recorded for this step.
    pub trace: TransactionTrace,
}
16
/// An ordered log of named transaction traces that can be queried and compared step by step.
///
/// Steps are kept in the order they were recorded. Step names are not required to be unique.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Scenario {
    /// Recorded steps, in append order.
    steps: Vec<ScenarioStep>,
}
22
/// Failures reported by [`Scenario`] lookups and assertions.
///
/// Mismatch variants carry `expected` / `actual` as pre-rendered strings rather than the
/// compared values themselves.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ScenarioError {
    /// Two scenarios' trace sequences differ; returned by `Scenario::assert_replay_matches`.
    ReplayMismatch(TraceMismatch),
    /// A replay's final state did not match the expected one.
    /// NOTE(review): not produced by the `Scenario` methods in this file; confirm the producer.
    ReplayFinalStateMismatch {
        /// Rendering of the expected final state.
        expected: String,
        /// Rendering of the observed final state.
        actual: String,
    },
    /// A named ledger field differed during replay.
    /// NOTE(review): not produced by the `Scenario` methods in this file; confirm the producer.
    ReplayLedgerMismatch {
        /// Name of the ledger field that differed.
        field: &'static str,
        /// Rendering of the expected value.
        expected: String,
        /// Rendering of the observed value.
        actual: String,
    },
    /// No step with the given name was recorded; carries the requested name.
    MissingStep(String),
    /// A field of a recorded step's trace differed from the expected value.
    StepMismatch {
        /// Name of the step that was checked.
        step: String,
        /// Transaction id taken from the step's trace.
        transaction_id: TransactionId,
        /// Revision taken from the step's trace.
        revision: Revision,
        /// Name of the trace field that was compared (e.g. `"resource_commands"`).
        field: &'static str,
        /// Pretty `Debug` (`{:#?}`) rendering of the expected value.
        expected: String,
        /// Pretty `Debug` (`{:#?}`) rendering of the recorded value.
        actual: String,
    },
    /// A step failed to commit with a `GraphError`.
    /// NOTE(review): not produced by the `Scenario` methods in this file; confirm the producer.
    StepCommitFailed {
        /// Name of the step whose commit failed.
        step: String,
        /// Error returned by the graph.
        error: GraphError,
    },
    /// A named invariant failed at a particular step.
    /// NOTE(review): not produced by the `Scenario` methods in this file; confirm the producer.
    InvariantFailed {
        /// Name of the step at which the invariant failed.
        step: String,
        /// Name of the failing invariant.
        invariant: String,
        /// Transaction id associated with the failure.
        transaction_id: TransactionId,
        /// Revision associated with the failure.
        revision: Revision,
    },
}
80
81pub trait TraceRedactor {
83 fn step_name(&self, name: &str) -> String {
85 name.to_owned()
86 }
87
88 fn resource_key(&self, key: &ResourceKey) -> ResourceKey {
90 key.clone()
91 }
92
93 fn invariant_name(&self, name: &str) -> String {
95 name.to_owned()
96 }
97}
98
/// A [`TraceRedactor`] that leaves every identifier unchanged.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct NoRedaction;

// Uses only the trait's default methods, each of which returns a copy of its input.
impl TraceRedactor for NoRedaction {}
104
105impl Scenario {
106 pub fn new() -> Self {
108 Self::default()
109 }
110
111 pub fn record<C, O>(&mut self, name: impl Into<String>, result: &TransactionResult<C, O>) {
113 self.record_trace(name, result.trace());
114 }
115
116 pub fn record_trace(&mut self, name: impl Into<String>, trace: TransactionTrace) {
118 self.steps.push(ScenarioStep {
119 name: name.into(),
120 trace,
121 });
122 }
123
124 pub fn steps(&self) -> &[ScenarioStep] {
126 &self.steps
127 }
128
129 pub fn step(&self, name: &str) -> Result<&ScenarioStep, ScenarioError> {
131 self.steps
132 .iter()
133 .find(|step| step.name == name)
134 .ok_or_else(|| ScenarioError::MissingStep(name.to_owned()))
135 }
136
137 pub fn assert_replay_matches(&self, other: &Scenario) -> Result<(), ScenarioError> {
139 assert_transaction_traces_match(&self.traces(), &other.traces())
140 .map_err(ScenarioError::ReplayMismatch)
141 }
142
143 pub fn traces(&self) -> Vec<TransactionTrace> {
145 self.steps
146 .iter()
147 .map(|step| step.trace.clone())
148 .collect::<Vec<_>>()
149 }
150
151 pub fn resource_commands(&self) -> Vec<ResourceCommandTrace> {
153 self.steps
154 .iter()
155 .flat_map(|step| step.trace.resource_commands.iter().cloned())
156 .collect()
157 }
158
159 pub fn output_frames(&self) -> Vec<OutputFrameTrace> {
161 self.steps
162 .iter()
163 .flat_map(|step| step.trace.output_frames.iter().cloned())
164 .collect()
165 }
166
167 pub fn assert_step_resource_commands(
169 &self,
170 name: &str,
171 expected: &[ResourceCommandTrace],
172 ) -> Result<(), ScenarioError> {
173 let step = self.step(name)?;
174 if step.trace.resource_commands == expected {
175 Ok(())
176 } else {
177 Err(ScenarioError::StepMismatch {
178 step: name.to_owned(),
179 transaction_id: step.trace.transaction_id,
180 revision: step.trace.revision,
181 field: "resource_commands",
182 expected: format!("{expected:#?}"),
183 actual: format!("{:#?}", step.trace.resource_commands),
184 })
185 }
186 }
187
188 pub fn assert_step_output_frames(
190 &self,
191 name: &str,
192 expected: &[OutputFrameTrace],
193 ) -> Result<(), ScenarioError> {
194 let step = self.step(name)?;
195 if step.trace.output_frames == expected {
196 Ok(())
197 } else {
198 Err(ScenarioError::StepMismatch {
199 step: name.to_owned(),
200 transaction_id: step.trace.transaction_id,
201 revision: step.trace.revision,
202 field: "output_frames",
203 expected: format!("{expected:#?}"),
204 actual: format!("{:#?}", step.trace.output_frames),
205 })
206 }
207 }
208
209 pub fn assert_oracle<G, O>(
211 &self,
212 graph: &G,
213 inputs: &O::CanonicalInputs,
214 ) -> Result<OracleCheck<O::ExpectedState>, OracleMismatch<O::ExpectedState>>
215 where
216 O: FullRecomputeOracle<G>,
217 {
218 assert_incremental_equals_full::<G, O>(graph, inputs)
219 }
220
221 pub fn redacted(&self, redactor: &impl TraceRedactor) -> Self {
223 let steps = self
224 .steps
225 .iter()
226 .map(|step| ScenarioStep {
227 name: redactor.step_name(&step.name),
228 trace: redact_trace(&step.trace, redactor),
229 })
230 .collect();
231 Self { steps }
232 }
233
234 pub fn to_redacted_debug_string(&self, redactor: &impl TraceRedactor) -> String {
236 format!("{:#?}", self.redacted(redactor))
237 }
238}
239
240fn redact_trace(trace: &TransactionTrace, redactor: &impl TraceRedactor) -> TransactionTrace {
241 let mut trace = trace.clone();
242 for command in &mut trace.resource_commands {
243 command.key = redactor.resource_key(&command.key);
244 }
245 for result in &mut trace.invariant_results {
246 result.name = redactor.invariant_name(&result.name);
247 }
248 trace
249}