use crate::{
constraint_vm::{
self,
error::{CheckError, ConstraintErrors, ConstraintsUnsatisfied},
},
state_read_vm::{
self, asm::FromBytesError, error::StateReadError, Access, BytecodeMapped, Gas, GasLimit,
SolutionAccess, StateRead, StateSlotSlice, StateSlots,
},
types::{
predicate::{Directive, Predicate},
solution::{Solution, SolutionData, SolutionDataIndex},
Key, PredicateAddress, Word,
},
};
use crate::constraint_vm::TransientData;
#[cfg(feature = "tracing")]
use essential_hash::content_addr;
use std::{collections::HashSet, fmt, sync::Arc};
use thiserror::Error;
use tokio::task::JoinSet;
#[cfg(feature = "tracing")]
use tracing::Instrument;
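/// Configuration options for the predicate checking functions
/// ([`check_predicates`] and [`check_predicate`]).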
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct CheckPredicateConfig {
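    /// Whether to continue checking after a failure in order to collect all
    /// failures, rather than returning early on the first one.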
pub collect_all_failures: bool,
}
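/// [`check`] error indicating that a [`Solution`] is invalid.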
#[derive(Debug, Error)]
pub enum InvalidSolution {
#[error("invalid solution data: {0}")]
Data(#[from] InvalidSolutionData),
#[error("state mutations validation failed: {0}")]
StateMutations(#[from] InvalidStateMutations),
#[error("transient data validation failed: {0}")]
TransientData(#[from] InvalidTransientData),
}
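/// Error indicating that a solution's `data` is invalid.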
#[derive(Debug, Error)]
pub enum InvalidSolutionData {
#[error("must be at least one solution data")]
Empty,
#[error("the number of solution data ({0}) exceeds the limit ({MAX_SOLUTION_DATA})")]
TooMany(usize),
#[error("data {0} expects too many decision vars {1} (limit: {MAX_DECISION_VARIABLES})")]
TooManyDecisionVariables(usize, usize),
#[error("Invalid state mutation entry: {0}")]
StateMutationEntry(KvError),
#[error("Invalid transient data entry: {0}")]
TransientDataEntry(KvError),
#[error("Decision variable value len {0} exceeds limit {MAX_VALUE_SIZE}")]
DecVarValueTooLarge(usize),
}
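/// Error indicating that a key or value exceeds the size limits.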
#[derive(Debug, Error)]
pub enum KvError {
#[error("key with length {0} exceeds limit {MAX_KEY_SIZE}")]
KeyTooLarge(usize),
#[error("value with length {0} exceeds limit {MAX_VALUE_SIZE}")]
ValueTooLarge(usize),
}
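/// Error indicating that a solution's state mutations are invalid.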
#[derive(Debug, Error)]
pub enum InvalidStateMutations {
#[error("the number of state mutations ({0}) exceeds the limit ({MAX_STATE_MUTATIONS})")]
TooMany(usize),
#[error("state mutation pathway {0} out of range of solution data")]
PathwayOutOfRangeOfSolutionData(u16),
#[error("attempt to apply multiple mutations to the same slot: {0:?} {1:?}")]
MultipleMutationsForSlot(PredicateAddress, Key),
}
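/// Error indicating that a solution's transient data is invalid.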
#[derive(Debug, Error)]
pub enum InvalidTransientData {
#[error("the number of transient data ({0}) exceeds the limit ({MAX_TRANSIENT_DATA})")]
TooMany(usize),
}
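/// Any error that may occur while checking all of a solution's predicates via
/// [`check_predicates`].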
#[derive(Debug, Error)]
pub enum PredicatesError<E> {
#[error("{0}")]
Failed(#[from] PredicateErrors<E>),
#[error("one or more spawned tasks failed to join: {0}")]
Join(#[from] tokio::task::JoinError),
#[error("summing solution data utility overflowed")]
UtilityOverflowed,
#[error("summing solution data gas overflowed")]
GasOverflowed,
}
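/// Predicate checking failed for the solution data at the contained indices.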
#[derive(Debug, Error)]
pub struct PredicateErrors<E>(pub Vec<(SolutionDataIndex, PredicateError<E>)>);
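/// Any error that may occur while checking a single solution data against its
/// associated predicate.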
#[derive(Debug, Error)]
pub enum PredicateError<E> {
#[error("failed to parse an op during bytecode mapping: {0}")]
OpsFromBytesError(#[from] FromBytesError),
#[error("state read execution error: {0}")]
StateRead(#[from] StateReadError<E>),
#[error("constraint checking failed: {0}")]
Constraints(#[from] PredicateConstraintsError),
}
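/// The number of decision variables provided by a solution data does not match
/// the number expected by its predicate.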
#[derive(Debug, Error)]
#[error("number of solution data decision variables ({data}) differs from predicate ({predicate})")]
pub struct InvalidDecisionVariablesLength {
pub data: usize,
pub predicate: u32,
}
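/// Any error that may occur while checking a predicate's constraints and
/// evaluating its directive.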
#[derive(Debug, Error)]
pub enum PredicateConstraintsError {
#[error("check failed: {0}")]
Check(#[from] constraint_vm::error::CheckError),
#[error("failed to recv: {0}")]
Recv(#[from] tokio::sync::oneshot::error::RecvError),
#[error("failed to calculate utility: {0}")]
Utility(#[from] UtilityError),
}
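/// The utility score of a solution.
///
/// Each solution data contributes a value in the range `[0.0, 1.0]`;
/// [`check_predicates`] returns the sum over all solution data.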
pub type Utility = f64;
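/// Any error that may occur while calculating the utility of a solution.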
#[derive(Debug, Error)]
pub enum UtilityError {
#[error("the range specified by the directive [{0}..{1}] is invalid")]
InvalidDirectiveRange(Word, Word),
#[error("invalid stack result after directive execution: {0}")]
InvalidStack(#[from] constraint_vm::error::StackError),
#[error("directive execution with constraint VM failed: {0}")]
Execution(#[from] constraint_vm::error::ConstraintError),
#[error("failed to recv: {0}")]
Recv(#[from] tokio::sync::oneshot::error::RecvError),
}
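/// Maximum number of decision variables of a single solution data.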
pub const MAX_DECISION_VARIABLES: u32 = 100;
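/// Maximum number of solution data within a solution.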
pub const MAX_SOLUTION_DATA: usize = 100;
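/// Maximum number of state mutations within a solution.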
pub const MAX_STATE_MUTATIONS: usize = 1000;
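/// Maximum number of transient data entries within a solution.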
pub const MAX_TRANSIENT_DATA: usize = 1000;
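/// Maximum number of words in a single value (decision variable, state
/// mutation or transient data value).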
pub const MAX_VALUE_SIZE: usize = 10_000;
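/// Maximum number of words in a state mutation or transient data key.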
pub const MAX_KEY_SIZE: usize = 1000;
impl<E: fmt::Display> fmt::Display for PredicateErrors<E> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("predicate checking failed for one or more solution data:\n")?;
for (ix, err) in &self.0 {
            writeln!(f, " {ix}: {err}")?;
}
Ok(())
}
}
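/// Validate a solution's structure against the limits imposed by this module.
///
/// Checks the solution's data ([`check_data`]), state mutations
/// ([`check_state_mutations`]) and transient data ([`check_transient_data`]).
///
/// A minimal usage sketch (illustrative only; assumes `solution` is a
/// [`Solution`] constructed by the caller):
///
/// ```ignore
/// // Cheap structural validation before any stateful predicate checking.
/// check(&solution)?;
/// ```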
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(solution = %content_addr(&solution.data)), err))]
pub fn check(solution: &Solution) -> Result<(), InvalidSolution> {
check_data(&solution.data)?;
check_state_mutations(solution)?;
check_transient_data(solution)?;
Ok(())
}
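/// Verify that a value's length is within [`MAX_VALUE_SIZE`].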
fn check_value_size(value: &[Word]) -> Result<(), KvError> {
if value.len() > MAX_VALUE_SIZE {
Err(KvError::ValueTooLarge(value.len()))
} else {
Ok(())
}
}
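/// Verify that a key's length is within [`MAX_KEY_SIZE`].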
fn check_key_size(key: &[Word]) -> Result<(), KvError> {
    if key.len() > MAX_KEY_SIZE {
        Err(KvError::KeyTooLarge(key.len()))
    } else {
        Ok(())
    }
}
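/// Validate a solution's data.
///
/// The data must be non-empty and within [`MAX_SOLUTION_DATA`], and each
/// data's decision variables must be within [`MAX_DECISION_VARIABLES`] with
/// each value within [`MAX_VALUE_SIZE`].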
pub fn check_data(data_slice: &[SolutionData]) -> Result<(), InvalidSolutionData> {
if data_slice.is_empty() {
return Err(InvalidSolutionData::Empty);
}
if data_slice.len() > MAX_SOLUTION_DATA {
return Err(InvalidSolutionData::TooMany(data_slice.len()));
}
for (data_ix, data) in data_slice.iter().enumerate() {
if data.decision_variables.len() > MAX_DECISION_VARIABLES as usize {
return Err(InvalidSolutionData::TooManyDecisionVariables(
data_ix,
data.decision_variables.len(),
));
}
for v in &data.decision_variables {
check_value_size(v).map_err(|_| InvalidSolutionData::DecVarValueTooLarge(v.len()))?;
}
}
Ok(())
}
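/// Validate a solution's state mutations.
///
/// The total number of mutations must be within [`MAX_STATE_MUTATIONS`], each
/// key and value must be within the size limits, and no solution data may
/// mutate the same key more than once.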
pub fn check_state_mutations(solution: &Solution) -> Result<(), InvalidSolution> {
if solution.state_mutations_len() > MAX_STATE_MUTATIONS {
return Err(InvalidStateMutations::TooMany(solution.state_mutations_len()).into());
}
for data in &solution.data {
let mut mut_keys = HashSet::new();
for mutation in &data.state_mutations {
if !mut_keys.insert(&mutation.key) {
return Err(InvalidStateMutations::MultipleMutationsForSlot(
data.predicate_to_solve.clone(),
mutation.key.clone(),
)
.into());
}
check_key_size(&mutation.key).map_err(InvalidSolutionData::StateMutationEntry)?;
check_value_size(&mutation.value).map_err(InvalidSolutionData::StateMutationEntry)?;
}
}
Ok(())
}
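/// Validate a solution's transient data.
///
/// The total number of entries must be within [`MAX_TRANSIENT_DATA`] and each
/// key and value must be within the size limits.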
pub fn check_transient_data(solution: &Solution) -> Result<(), InvalidSolution> {
if solution.transient_data_len() > MAX_TRANSIENT_DATA {
return Err(InvalidTransientData::TooMany(solution.transient_data_len()).into());
}
for data in &solution.data {
for mutation in &data.transient_data {
check_key_size(&mutation.key).map_err(InvalidSolutionData::TransientDataEntry)?;
check_value_size(&mutation.value).map_err(InvalidSolutionData::TransientDataEntry)?;
}
}
Ok(())
}
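/// Check all of a solution's data against their associated predicates.
///
/// For each solution data, a task is spawned that reads pre- and post-state
/// via the given [`StateRead`] implementations and checks the predicate
/// resolved by `get_predicate`.
///
/// Returns the sum of utilities and the total gas spent across all solution
/// data, or a [`PredicatesError`] describing the failures.
///
/// A usage sketch (illustrative only; `pre_state`, `post_state`, `solution`
/// and `predicates` are assumed to be provided by the caller):
///
/// ```ignore
/// let config = Arc::new(CheckPredicateConfig::default());
/// let (utility, gas) = check_predicates(
///     &pre_state,
///     &post_state,
///     Arc::new(solution),
///     |addr| predicates[addr].clone(),
///     config,
/// )
/// .await?;
/// ```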
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all))]
pub async fn check_predicates<SA, SB>(
pre_state: &SA,
post_state: &SB,
solution: Arc<Solution>,
get_predicate: impl Fn(&PredicateAddress) -> Arc<Predicate>,
config: Arc<CheckPredicateConfig>,
) -> Result<(Utility, Gas), PredicatesError<SA::Error>>
where
SA: Clone + StateRead + Send + Sync + 'static,
SB: Clone + StateRead<Error = SA::Error> + Send + Sync + 'static,
SA::Future: Send,
SB::Future: Send,
SA::Error: Send,
{
#[cfg(feature = "tracing")]
tracing::trace!("{}", essential_hash::content_addr(&*solution));
let transient_data: Arc<TransientData> =
Arc::new(essential_constraint_vm::transient_data(&solution));
let mut set: JoinSet<(_, Result<_, PredicateError<SA::Error>>)> = JoinSet::new();
for (solution_data_index, data) in solution.data.iter().enumerate() {
let solution_data_index: SolutionDataIndex = solution_data_index
.try_into()
.expect("solution data index already validated");
let predicate = get_predicate(&data.predicate_to_solve);
let solution = solution.clone();
let transient_data = transient_data.clone();
let pre_state: SA = pre_state.clone();
let post_state: SB = post_state.clone();
let config = config.clone();
let future = async move {
let pre_state = pre_state;
let post_state = post_state;
let res = check_predicate(
&pre_state,
&post_state,
solution,
predicate,
solution_data_index,
&config,
transient_data,
)
.await;
(solution_data_index, res)
};
#[cfg(feature = "tracing")]
let future = future.in_current_span();
set.spawn(future);
}
let mut total_gas: u64 = 0;
let mut utility: f64 = 0.0;
let mut failed = vec![];
while let Some(res) = set.join_next().await {
let (solution_data_ix, res) = res?;
let (u, g) = match res {
Ok(ok) => ok,
Err(e) => {
failed.push((solution_data_ix, e));
if config.collect_all_failures {
continue;
} else {
return Err(PredicateErrors(failed).into());
}
}
};
utility += u;
if utility == f64::INFINITY {
return Err(PredicatesError::UtilityOverflowed);
}
total_gas = total_gas
.checked_add(g)
.ok_or(PredicatesError::GasOverflowed)?;
}
if !failed.is_empty() {
return Err(PredicateErrors(failed).into());
}
Ok((utility, total_gas))
}
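/// Check a single solution data against its associated predicate.
///
/// Runs each of the predicate's state read programs against both pre- and
/// post-state to build up the state slots, then checks the predicate's
/// constraints and evaluates its directive.
///
/// Returns the utility produced by the directive along with the total gas
/// spent on state reads.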
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(
        skip_all,
        fields(
            solution = %format!("{}", content_addr(&*solution))[0..8],
            data = solution_data_index,
        ),
    ),
)]
pub async fn check_predicate<SA, SB>(
pre_state: &SA,
post_state: &SB,
solution: Arc<Solution>,
predicate: Arc<Predicate>,
solution_data_index: SolutionDataIndex,
config: &CheckPredicateConfig,
transient_data: Arc<TransientData>,
) -> Result<(Utility, Gas), PredicateError<SA::Error>>
where
SA: StateRead + Sync,
SB: StateRead<Error = SA::Error> + Sync,
{
let mut total_gas = 0;
let mut pre_slots: Vec<Vec<Word>> = Vec::new();
let mut post_slots: Vec<Vec<Word>> = Vec::new();
let mutable_keys = constraint_vm::mut_keys_set(&solution, solution_data_index);
let solution_access = SolutionAccess::new(
&solution,
solution_data_index,
&mutable_keys,
&transient_data,
);
for (state_read_index, state_read) in predicate.state_read.iter().enumerate() {
#[cfg(not(feature = "tracing"))]
let _ = state_read_index;
let state_read_mapped = BytecodeMapped::try_from(&state_read[..])?;
let future = read_state_slots(
&state_read_mapped,
Access {
solution: solution_access,
state_slots: StateSlots {
pre: &pre_slots,
post: &post_slots,
},
},
pre_state,
);
#[cfg(feature = "tracing")]
let (gas, new_pre_slots) = future
.instrument(tracing::info_span!("pre", ix = state_read_index))
.await?;
#[cfg(not(feature = "tracing"))]
let (gas, new_pre_slots) = future.await?;
total_gas += gas;
pre_slots.extend(new_pre_slots);
let future = read_state_slots(
&state_read_mapped,
Access {
solution: solution_access,
state_slots: StateSlots {
pre: &pre_slots,
post: &post_slots,
},
},
post_state,
);
#[cfg(feature = "tracing")]
let (gas, new_post_slots) = future
.instrument(tracing::info_span!("post", ix = state_read_index))
.await?;
#[cfg(not(feature = "tracing"))]
let (gas, new_post_slots) = future.await?;
total_gas += gas;
post_slots.extend(new_post_slots);
}
let utility = check_predicate_constraints(
solution,
solution_data_index,
predicate.clone(),
Arc::from(pre_slots.into_boxed_slice()),
Arc::from(post_slots.into_boxed_slice()),
config,
transient_data,
)
.await?;
Ok((utility, total_gas))
}
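/// Execute the given mapped state read bytecode with the state read VM,
/// returning the gas spent along with the resulting state slots.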
async fn read_state_slots<S>(
bytecode_mapped: &BytecodeMapped<&[u8]>,
access: Access<'_>,
state_read: &S,
) -> Result<(Gas, Vec<Vec<Word>>), state_read_vm::error::StateReadError<S::Error>>
where
S: StateRead,
{
let mut vm = state_read_vm::Vm::default();
let gas_spent = vm
.exec_bytecode(
bytecode_mapped,
access,
state_read,
&|_: &state_read_vm::asm::Op| 1,
GasLimit::UNLIMITED,
)
.await?;
Ok((gas_spent, vm.into_state_slots()))
}
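/// Check the predicate's constraints against the given solution data and state
/// slots, then evaluate its directive to produce the utility.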
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, "check"))]
pub async fn check_predicate_constraints(
solution: Arc<Solution>,
solution_data_index: SolutionDataIndex,
predicate: Arc<Predicate>,
pre_slots: Arc<StateSlotSlice>,
post_slots: Arc<StateSlotSlice>,
config: &CheckPredicateConfig,
transient_data: Arc<TransientData>,
) -> Result<Utility, PredicateConstraintsError> {
match check_predicate_constraints_parallel(
solution.clone(),
solution_data_index,
predicate.clone(),
pre_slots.clone(),
post_slots.clone(),
config,
transient_data.clone(),
)
.await
{
Ok(()) => {
#[cfg(feature = "tracing")]
tracing::trace!("constraint check complete");
match calculate_utility(
solution,
solution_data_index,
predicate.clone(),
pre_slots,
post_slots,
transient_data,
)
.await
{
Ok(util) => {
#[cfg(feature = "tracing")]
tracing::trace!("utility: {}", util);
Ok(util)
}
Err(err) => {
#[cfg(feature = "tracing")]
tracing::trace!("error calculating utility: {}", err);
Err(err.into())
}
}
}
Err(err) => {
#[cfg(feature = "tracing")]
tracing::trace!("error checking constraints: {}", err);
Err(err)
}
}
}
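/// Check each of the predicate's constraints in parallel on the rayon thread
/// pool, collecting any failed or unsatisfied constraints.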
async fn check_predicate_constraints_parallel(
solution: Arc<Solution>,
solution_data_index: SolutionDataIndex,
predicate: Arc<Predicate>,
pre_slots: Arc<StateSlotSlice>,
post_slots: Arc<StateSlotSlice>,
config: &CheckPredicateConfig,
transient_data: Arc<TransientData>,
) -> Result<(), PredicateConstraintsError> {
let mut handles = Vec::with_capacity(predicate.constraints.len());
for ix in 0..predicate.constraints.len() {
let (tx, rx) = tokio::sync::oneshot::channel();
handles.push(rx);
let solution = solution.clone();
let transient_data = transient_data.clone();
let pre_slots = pre_slots.clone();
let post_slots = post_slots.clone();
let predicate = predicate.clone();
#[cfg(feature = "tracing")]
let span = tracing::Span::current();
rayon::spawn(move || {
#[cfg(feature = "tracing")]
let span = tracing::trace_span!(parent: &span, "constraint", ix = ix as u32);
#[cfg(feature = "tracing")]
let guard = span.enter();
let mutable_keys = constraint_vm::mut_keys_set(&solution, solution_data_index);
let solution_access = SolutionAccess::new(
&solution,
solution_data_index,
&mutable_keys,
&transient_data,
);
let access = Access {
solution: solution_access,
state_slots: StateSlots {
pre: &pre_slots,
post: &post_slots,
},
};
let res = constraint_vm::eval_bytecode_iter(
predicate
.constraints
.get(ix)
.expect("Safe due to above len check")
.iter()
.copied(),
access,
);
let _ = tx.send((ix, res));
#[cfg(feature = "tracing")]
drop(guard)
})
}
let mut failed = Vec::new();
let mut unsatisfied = Vec::new();
for handle in handles {
let (ix, res): (usize, Result<bool, _>) = handle.await?;
match res {
Err(err) => {
failed.push((ix, err));
if !config.collect_all_failures {
break;
}
}
Ok(b) if !b => unsatisfied.push(ix),
_ => (),
}
}
if !failed.is_empty() {
return Err(CheckError::from(ConstraintErrors(failed)).into());
}
if !unsatisfied.is_empty() {
return Err(CheckError::from(ConstraintsUnsatisfied(unsatisfied)).into());
}
Ok(())
}
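/// Evaluate the predicate's directive to produce its utility.
///
/// `Satisfy` directives yield a utility of `1.0`. `Maximize` and `Minimize`
/// directives execute their bytecode on the rayon thread pool and normalise
/// the resulting `[start, end, value]` stack into the range `[0.0, 1.0]`.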
async fn calculate_utility(
solution: Arc<Solution>,
solution_data_index: SolutionDataIndex,
predicate: Arc<Predicate>,
pre_slots: Arc<StateSlotSlice>,
post_slots: Arc<StateSlotSlice>,
transient_data: Arc<TransientData>,
) -> Result<Utility, UtilityError> {
match &predicate.directive {
Directive::Satisfy => return Ok(1.0),
Directive::Maximize(_) | Directive::Minimize(_) => (),
}
let (tx, rx) = tokio::sync::oneshot::channel();
#[cfg(feature = "tracing")]
let span = tracing::Span::current();
rayon::spawn(move || {
#[cfg(feature = "tracing")]
let span = tracing::trace_span!(parent: &span, "utility");
#[cfg(feature = "tracing")]
let guard = span.enter();
let mutable_keys = constraint_vm::mut_keys_set(&solution, solution_data_index);
let solution_access = SolutionAccess::new(
&solution,
solution_data_index,
&mutable_keys,
&transient_data,
);
let access = Access {
solution: solution_access,
state_slots: StateSlots {
pre: &pre_slots,
post: &post_slots,
},
};
        let code = match predicate.directive {
            Directive::Maximize(ref code) | Directive::Minimize(ref code) => code,
            _ => unreachable!("`Satisfy` is handled above"),
        };
let res = constraint_vm::exec_bytecode_iter(code.iter().copied(), access)
.map_err(UtilityError::from)
.and_then(|mut stack| {
let [start, end, value] = stack.pop3()?;
let util = normalize_utility(value, start, end)?;
Ok(util)
});
let _ = tx.send(res);
#[cfg(feature = "tracing")]
drop(guard)
});
rx.await?
}
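/// Normalise `value` into the range `[0.0, 1.0]` relative to `[start, end]`,
/// clamping the result.
///
/// Returns an error if `start >= end`.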
fn normalize_utility(value: Word, start: Word, end: Word) -> Result<Utility, UtilityError> {
if start >= end {
return Err(UtilityError::InvalidDirectiveRange(start, end));
}
let normalized = (value - start) as f64 / (end - start) as f64;
Ok(normalized.clamp(0.0, 1.0))
}