use std::collections::HashMap;
use itertools::Itertools;
use rtlola_streamir::{
formatter::{files::FilesFormatter, statements::StmtFormatter, types::TypeFormatter},
ir::{InputReference, Stmt, StreamIr, StreamReference},
optimize,
rewrite_rules::EvaluateGuards,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::{
functions::{FunctionDefinition, OutputsVerdict, Visibility},
RequirementKey, SolidityFormatter,
};
#[derive(Error, Debug)]
pub enum PartialEvaluationError {
#[error("The stream {0} is expected to be an input stream, but was an output stream.")]
ExpectedInput(String),
#[error("The stream {0} does not exist in the specification.")]
NotExisting(String),
}
pub(crate) fn partial_evaluation(
config: &InterfaceConfig,
lir: &StreamIr,
) -> Result<Vec<(InterfaceFunction, Stmt)>, PartialEvaluationError> {
let functions = config.check(lir)?;
let functions_with_stmts = functions
.into_iter()
.map(|function| {
let inputs = function.inputs();
let stmt = optimize(
lir.clone(),
vec![Box::new(EvaluateGuards::only_streams(inputs))],
)
.unwrap()
.stmt;
(function, stmt)
})
.collect::<_>();
Ok(functions_with_stmts)
}
impl FunctionDefinition for (InterfaceFunction, Stmt) {
fn name(&self, _f: &SolidityFormatter) -> String {
self.0.name.clone()
}
fn header(&self, f: &SolidityFormatter) -> String {
format!(
"{}({})",
self.name(f),
self.0
.arguments
.iter()
.filter(|arg| arg.parameter)
.map(|arg| format!(
"{} {}",
f.ty(f.stream_type(StreamReference::In(arg.stream_ref)).clone()),
arg.name,
))
.join(", ")
)
}
fn body(self, f: &SolidityFormatter) -> String {
let code_inputs = self
.0
.arguments
.iter()
.filter(|arg| !arg.parameter)
.map(|arg| {
format!(
"{} {} = {};",
f.ty(f.stream_type(StreamReference::In(arg.stream_ref)).clone()),
arg.name,
arg.source
)
})
.join("\n");
vec![
Some(code_inputs),
Some(f.stmt(self.1.clone())),
(!f.outputs_verdict.is_empty()).then(|| OutputsVerdict.emit(f)),
]
.into_iter()
.flatten()
.join("\n")
}
fn key(&self) -> RequirementKey {
RequirementKey::Cycle(self.0.name.clone())
}
fn visibility(&self) -> Visibility {
Visibility::Public
}
fn payable(&self) -> bool {
self.0.tags.contains(&"payable".to_string())
}
}
impl SolidityFormatter {
pub(crate) fn format_cycle_functions(
&self,
config: &InterfaceConfig,
lir: StreamIr,
) -> Result<(), PartialEvaluationError> {
let functions = partial_evaluation(config, &lir)?;
for function in functions {
self.add_requirement((function,));
}
Ok(())
}
}
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct InterfaceConfig {
function: Vec<FunctionConfig>,
#[serde(default = "Vec::new")]
additional_input: Vec<AdditionalInput>,
}
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct FunctionConfig {
name: String,
stream_name: Option<String>,
#[serde(default = "Vec::new")]
argument: Vec<ParameterConfig>,
#[serde(default = "Vec::new")]
additional_input: Vec<AdditionalInput>,
#[serde(default)]
tags: Vec<String>,
}
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct ParameterConfig {
name: String,
stream_name: Option<String>,
code: Option<String>,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct AdditionalInput {
input: String,
source: String,
parameter: bool,
}
#[derive(Debug, Clone)]
pub struct InterfaceFunction {
pub name: String,
pub arguments: Vec<InputSource>,
pub tags: Vec<String>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct InputSource {
pub source: String,
pub stream_ref: InputReference,
pub name: String,
pub parameter: bool,
}
impl InterfaceConfig {
pub fn from_toml(toml: &str) -> Result<Self, String> {
toml::from_str(toml).map_err(|e| e.to_string())
}
}
impl InterfaceConfig {
fn check(&self, lir: &StreamIr) -> Result<Vec<InterfaceFunction>, PartialEvaluationError> {
let InterfaceConfig {
function,
additional_input,
} = self;
let functions = function
.iter()
.map(|function| {
let function = function.check(additional_input, lir)?;
Ok(function)
})
.collect::<Result<Vec<InterfaceFunction>, _>>()?;
Ok(functions)
}
}
impl InterfaceFunction {
fn inputs(&self) -> Vec<StreamReference> {
self.arguments
.iter()
.map(|argument| StreamReference::In(argument.stream_ref))
.collect()
}
}
impl FunctionConfig {
fn check(
&self,
global_additional_inputs: &[AdditionalInput],
ir: &StreamIr,
) -> Result<InterfaceFunction, PartialEvaluationError> {
let FunctionConfig {
name,
stream_name,
argument,
additional_input,
tags,
} = self;
let arguments: Vec<_> = argument
.iter()
.map(|argument| argument.to_input(ir))
.collect::<Result<Vec<InputSource>, _>>()?;
let additional_inputs: Vec<_> = global_additional_inputs
.iter()
.chain(additional_input)
.map(|input| input.check(ir))
.collect::<Result<Vec<InputSource>, _>>()?;
let inputs: Vec<_> = arguments.iter().cloned().chain(additional_inputs).collect();
let mut inputs = inputs.iter().fold(HashMap::new(), |mut inputs, argument| {
inputs.insert(argument.stream_ref, argument.source.clone());
inputs
});
if let Some(stream_name) = stream_name {
let idx = match ir.stream_by_name(stream_name) {
Some(StreamReference::In(idx)) => idx,
Some(StreamReference::Out(_)) => {
return Err(PartialEvaluationError::ExpectedInput(stream_name.clone()))
}
None => return Err(PartialEvaluationError::NotExisting(stream_name.clone())),
};
inputs.insert(idx, "()".into());
}
Ok(InterfaceFunction {
name: name.into(),
arguments,
tags: tags.clone(),
})
}
}
impl ParameterConfig {
fn to_input(&self, lir: &StreamIr) -> Result<InputSource, PartialEvaluationError> {
let ParameterConfig {
name,
stream_name,
code,
} = self;
let stream_name = stream_name.as_ref().unwrap_or(name).clone();
AdditionalInput {
input: stream_name,
source: code.as_ref().cloned().unwrap_or_else(|| name.clone()),
parameter: code.is_none(),
}
.check(lir)
}
}
impl AdditionalInput {
fn check(&self, lir: &StreamIr) -> Result<InputSource, PartialEvaluationError> {
let sr = match lir.stream_by_name(&self.input) {
Some(StreamReference::Out(_)) => {
return Err(PartialEvaluationError::ExpectedInput(self.input.clone()))
}
Some(StreamReference::In(idx)) => idx,
None => return Err(PartialEvaluationError::NotExisting(self.input.clone())),
};
Ok(InputSource {
source: self.source.clone(),
stream_ref: sr,
name: self.input.clone(),
parameter: self.parameter,
})
}
}