use crate::errors::Result;
use crate::evaluators::Evaluator;
use crate::graphs::{create_context, Context, Graph, Operation};
use crate::optimizer::constant_optimizer::optimize_graph_constants;
use crate::optimizer::dangling_nodes_optimizer::optimize_graph_dangling_nodes;
use crate::optimizer::meta_operation_optimizer::optimize_graph_meta_operations;
use crate::random::PRNG;
pub fn optimize_context<T: Evaluator>(context: Context, mut evaluator: T) -> Result<Context> {
context.check_finalized()?;
evaluator.preprocess(context.clone())?;
let output_context = create_context()?;
for graph in context.get_graphs() {
let (_const_context, const_graph) = graph_in_new_context(graph.clone())?;
optimize_graph_constants(graph.clone(), const_graph.clone(), &mut evaluator)?;
const_graph.finalize()?;
let (_meta_context, meta_graph) = graph_in_new_context(graph.clone())?;
optimize_graph_meta_operations(const_graph.clone(), meta_graph.clone())?;
meta_graph.finalize()?;
let final_graph = add_graph_to_context(output_context.clone(), graph.clone())?;
optimize_graph_dangling_nodes(meta_graph.clone(), final_graph.clone())?;
final_graph.finalize()?;
if graph == context.get_main_graph()? {
final_graph.set_as_main()?;
}
}
output_context.finalize()?;
Ok(output_context)
}
fn add_graph_to_context(context: Context, source_graph: Graph) -> Result<Graph> {
let new_graph = context.create_graph()?;
for annotation in source_graph.get_annotations()? {
new_graph.add_annotation(annotation)?;
}
Ok(new_graph)
}
fn graph_in_new_context(source_graph: Graph) -> Result<(Context, Graph)> {
let context = create_context()?;
let graph = add_graph_to_context(context.clone(), source_graph)?;
Ok((context, graph))
}
#[doc(hidden)]
pub fn stress_test<T: Evaluator>(
c1: Context,
c2: Context,
ip_evaluator1: T,
ip_evaluator2: T,
) -> Result<()> {
let seed = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F";
let mut prng = PRNG::new(Some(*seed)).unwrap();
let mut evaluator1 = ip_evaluator1;
evaluator1.preprocess(c1.clone()).unwrap();
let mut evaluator2 = ip_evaluator2;
evaluator2.preprocess(c2.clone()).unwrap();
for _ in 0..10 {
let mut inputs = vec![];
for node in c1.get_main_graph().unwrap().get_nodes() {
if let Operation::Input(t) = node.get_operation() {
inputs.push(prng.get_random_value(t).unwrap());
}
}
let result1 = evaluator1
.evaluate_graph(c1.get_main_graph().unwrap(), inputs.clone())
.unwrap();
let result2 = evaluator2
.evaluate_graph(c2.get_main_graph().unwrap(), inputs.clone())
.unwrap();
assert_eq!(result1, result2);
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::applications::sorting::create_batchers_sorting_graph;
use crate::custom_ops::run_instantiation_pass;
use crate::data_types::{scalar_type, UINT64};
use crate::data_values::Value;
use crate::evaluators::simple_evaluator::SimpleEvaluator;
use crate::graphs::create_context;
use crate::inline::inline_ops::{inline_operations, InlineConfig, InlineMode};
#[test]
fn test_simple() {
|| -> Result<()> {
let c = create_context()?;
let g = c.create_graph()?;
let i1 = g.input(scalar_type(UINT64))?;
let i2 = g.input(scalar_type(UINT64))?;
let n = i1.add(i2)?;
let o = n.add(g.constant(scalar_type(UINT64), Value::from_scalar(1, UINT64)?)?)?;
o.set_as_output()?;
g.finalize()?;
g.set_as_main()?;
c.finalize()?;
let optimized_c = optimize_context(c.clone(), SimpleEvaluator::new(None)?)?;
stress_test(
c,
optimized_c,
SimpleEvaluator::new(None).unwrap(),
SimpleEvaluator::new(None).unwrap(),
)?;
Ok(())
}()
.unwrap();
}
#[test]
fn test_optimizing_sorting_graph() {
|| -> Result<()> {
let c = create_context()?;
let g = create_batchers_sorting_graph(c.clone(), 4, UINT64)?;
g.set_as_main()?;
c.finalize()?;
let c1 = run_instantiation_pass(c)?.get_context();
let c2 = inline_operations(
c1,
InlineConfig {
default_mode: InlineMode::Simple,
..Default::default()
},
)?;
let optimized_c = optimize_context(c2.clone(), SimpleEvaluator::new(None)?)?;
assert!(
optimized_c.get_main_graph()?.get_nodes().len()
< c2.get_main_graph()?.get_nodes().len()
);
stress_test(
c2,
optimized_c,
SimpleEvaluator::new(None).unwrap(),
SimpleEvaluator::new(None).unwrap(),
)?;
Ok(())
}()
.unwrap();
}
}