use trellis_core::TransactionTrace;
use crate::{Scenario, ScenarioError};
/// Format version stamped into every `DataTransactionScript` and `SerializedScenario`
/// created by this module.
///
/// `validate_format_version` rejects any other value with
/// `ScenarioError::TraceFormatVersionMismatch`; bump this when the persisted layout changes.
pub const TRACE_FORMAT_VERSION: u32 = 1;
/// A version-stamped script made of named steps, each holding a batch of `Operation`s.
///
/// The script remembers the [`TRACE_FORMAT_VERSION`] it was created with so that a
/// persisted copy can later be checked for compatibility.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(rename_all = "camelCase")
)]
pub struct DataTransactionScript<Operation> {
    /// Format version of this script (serialized as `formatVersion` with the `serde` feature).
    format_version: u32,
    /// Committed steps, in the order they were committed.
    steps: Vec<DataScriptStep<Operation>>,
}
/// One named step of a `DataTransactionScript`, holding its operations in order.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(rename_all = "camelCase")
)]
pub struct DataScriptStep<Operation> {
    /// Label given when the step was started.
    name: String,
    /// Operations in the order they were added to the step.
    operations: Vec<Operation>,
}
/// Builder returned by `DataTransactionScript::step` that collects the operations of a
/// single step.
///
/// Nothing is added to the script until `commit` is called. A builder that is dropped
/// without committing silently discards its operations, which is why the type is
/// `#[must_use]`: forgetting `.commit()` at the end of a chain now produces a warning.
#[must_use = "the step is only added to the script when `commit()` is called"]
pub struct DataScriptStepBuilder<'script, Operation> {
    /// Script that receives the step on commit.
    script: &'script mut DataTransactionScript<Operation>,
    /// Name the committed step will carry.
    name: String,
    /// Operations collected so far, in insertion order.
    operations: Vec<Operation>,
}
/// Version-stamped, serializable snapshot of the steps recorded in a `Scenario`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(rename_all = "camelCase")
)]
pub struct SerializedScenario {
    /// Format version this snapshot was written with; checked by `into_scenario`.
    format_version: u32,
    /// One entry per scenario step.
    steps: Vec<SerializedScenarioStep>,
}
/// A single named step of a `SerializedScenario` together with its transaction trace.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(rename_all = "camelCase")
)]
pub struct SerializedScenarioStep {
    /// Step name, as recorded in the originating scenario.
    pub name: String,
    /// Trace captured for this step.
    pub trace: TransactionTrace,
}
impl<Operation> DataTransactionScript<Operation> {
    /// Creates an empty script stamped with the current [`TRACE_FORMAT_VERSION`].
    pub fn new() -> Self {
        let steps = Vec::new();
        Self {
            steps,
            format_version: TRACE_FORMAT_VERSION,
        }
    }

    /// Starts a new step called `name`.
    ///
    /// The returned builder collects operations; the step is appended to this script
    /// only when `DataScriptStepBuilder::commit` is called.
    pub fn step(&mut self, name: impl Into<String>) -> DataScriptStepBuilder<'_, Operation> {
        let name: String = name.into();
        DataScriptStepBuilder {
            name,
            operations: Vec::new(),
            script: self,
        }
    }

    /// Returns the format version this script carries.
    pub fn format_version(&self) -> u32 {
        self.format_version
    }

    /// Returns the committed steps in commit order.
    pub fn steps(&self) -> &[DataScriptStep<Operation>] {
        self.steps.as_slice()
    }

    /// Checks that this script's format version equals [`TRACE_FORMAT_VERSION`].
    ///
    /// # Errors
    ///
    /// Returns `ScenarioError::TraceFormatVersionMismatch` when the versions differ.
    pub fn validate_format_version(&self) -> Result<(), ScenarioError> {
        let version = self.format_version();
        validate_format_version(version)
    }

    /// Serializes the script to pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// Propagates any `serde_json` serialization error.
    #[cfg(feature = "serde")]
    pub fn to_json(&self) -> Result<String, serde_json::Error>
    where
        Operation: serde::Serialize,
    {
        serde_json::to_string_pretty(self)
    }

    /// Parses a script from JSON.
    ///
    /// The format version is *not* checked here; call `validate_format_version`
    /// on the result before trusting it.
    ///
    /// # Errors
    ///
    /// Propagates any `serde_json` deserialization error.
    #[cfg(feature = "serde")]
    pub fn from_json(json: &str) -> Result<Self, serde_json::Error>
    where
        Operation: serde::de::DeserializeOwned,
    {
        serde_json::from_str::<Self>(json)
    }
}
impl<Operation> Default for DataTransactionScript<Operation> {
fn default() -> Self {
Self::new()
}
}
impl<Operation> DataScriptStep<Operation> {
    /// Returns the step's name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the step's operations in the order they were added.
    pub fn operations(&self) -> &[Operation] {
        self.operations.as_slice()
    }
}
impl<Operation> DataScriptStepBuilder<'_, Operation> {
    /// Appends `operation` to the pending step and hands the builder back for chaining.
    pub fn operation(self, operation: Operation) -> Self {
        let mut builder = self;
        builder.operations.push(operation);
        builder
    }

    /// Finalizes the step, appending its name and collected operations to the end of
    /// the owning script's step list.
    pub fn commit(self) {
        let Self {
            script,
            name,
            operations,
        } = self;
        let finished = DataScriptStep { name, operations };
        script.steps.push(finished);
    }
}
impl SerializedScenario {
    /// Copies every step of `scenario` (its name and trace) into a snapshot stamped with
    /// the current [`TRACE_FORMAT_VERSION`].
    pub fn from_scenario(scenario: &Scenario) -> Self {
        let mut steps = Vec::new();
        for step in scenario.steps().iter() {
            steps.push(SerializedScenarioStep {
                name: step.name.clone(),
                trace: step.trace.clone(),
            });
        }
        Self {
            format_version: TRACE_FORMAT_VERSION,
            steps,
        }
    }

    /// Returns the format version recorded in this snapshot.
    pub fn format_version(&self) -> u32 {
        self.format_version
    }

    /// Returns the serialized steps in stored order.
    pub fn steps(&self) -> &[SerializedScenarioStep] {
        self.steps.as_slice()
    }

    /// Rebuilds a `Scenario` by re-recording every step's trace in stored order.
    ///
    /// # Errors
    ///
    /// Returns `ScenarioError::TraceFormatVersionMismatch` if the snapshot's version is not
    /// [`TRACE_FORMAT_VERSION`] (checked before any step is recorded). Otherwise propagates
    /// the first error returned by `Scenario::record_trace`.
    pub fn into_scenario(self) -> Result<Scenario, ScenarioError> {
        let Self {
            format_version,
            steps,
        } = self;
        validate_format_version(format_version)?;

        let mut scenario = Scenario::new();
        for SerializedScenarioStep { name, trace } in steps {
            scenario.record_trace(name, trace)?;
        }
        Ok(scenario)
    }

    /// Replays a copy of this snapshot into a fresh `Scenario` and compares it with `actual`
    /// via `Scenario::assert_replay_matches`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `into_scenario` and from `assert_replay_matches`.
    pub fn assert_matches_scenario(&self, actual: &Scenario) -> Result<(), ScenarioError> {
        self.clone()
            .into_scenario()?
            .assert_replay_matches(actual)
    }

    /// Serializes the snapshot to pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// Propagates any `serde_json` serialization error.
    #[cfg(feature = "serde")]
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string_pretty(self)
    }

    /// Parses a snapshot from JSON. The version is checked later, by `into_scenario`.
    ///
    /// # Errors
    ///
    /// Propagates any `serde_json` deserialization error.
    #[cfg(feature = "serde")]
    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str::<Self>(json)
    }
}
/// Checks that `actual` equals [`TRACE_FORMAT_VERSION`].
///
/// # Errors
///
/// Returns `ScenarioError::TraceFormatVersionMismatch`, carrying both the expected and the
/// actual version, when they differ.
fn validate_format_version(actual: u32) -> Result<(), ScenarioError> {
    if actual != TRACE_FORMAT_VERSION {
        return Err(ScenarioError::TraceFormatVersionMismatch {
            expected: TRACE_FORMAT_VERSION,
            actual,
        });
    }
    Ok(())
}