use crate::data_types::Type;
use crate::data_values::Value;
use crate::errors::Result;
use crate::evaluators::Evaluator;
use crate::graphs::{copy_node_name, Graph, Node, Operation};
use std::cmp::Eq;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
/// Lookup key for the constant cache: the declared type of a constant
/// together with its value.
///
/// `Hash` and `PartialEq` are implemented by hand below rather than derived.
#[derive(Clone)]
struct ConstantKey {
// Declared type of the constant.
t: Type,
// Value of the constant.
v: Value,
}
impl Hash for ConstantKey {
fn hash<H: Hasher>(&self, state: &mut H) {
self.t.hash(state);
self.v.deep_hash(state);
}
}
impl PartialEq for ConstantKey {
    /// Two keys are equal when both their types and their values compare
    /// equal; the values are only compared if the types already match.
    fn eq(&self, other: &Self) -> bool {
        let types_match = self.t == other.t;
        types_match && self.v == other.v
    }
}
// Marker impl: together with `Hash` this lets `ConstantKey` be used as a
// `HashMap` key.
impl Eq for ConstantKey {}
pub(super) fn optimize_graph_constants(
graph: Graph,
out_graph: Graph,
evaluator: &mut dyn Evaluator,
) -> Result<()> {
graph.check_finalized()?;
let mut constant_cache = HashMap::<ConstantKey, Node>::new();
let mut node_mapping = HashMap::<Node, Node>::new();
let mut constant_nodes = HashMap::<Node, Value>::new();
for node in graph.get_nodes() {
if !node.get_graph_dependencies().is_empty() {
return Err(runtime_error!(
"Constant optimization works only on fully inlined graphs."
));
}
let mut resolve_const = |t: Type, val: Value, name: Option<String>| -> Result<Node> {
let key = ConstantKey {
t: t.clone(),
v: val.clone(),
};
if let std::collections::hash_map::Entry::Vacant(e) = constant_cache.entry(key.clone())
{
let constant_node = out_graph.constant(t, val)?;
if let Some(name) = name {
constant_node.set_name(&name)?;
}
e.insert(constant_node.clone());
Ok(constant_node)
} else {
Ok(constant_cache.get(&key).unwrap().clone())
}
};
let op = node.get_operation();
let new_node = match op {
Operation::Constant(t, val) => {
if !node.get_annotations()?.is_empty() {
return Err(runtime_error!(
"Constant optimization with annotations on const nodes in not supported"
));
}
let value_ptr = evaluator.evaluate_node(node.clone(), vec![])?;
constant_nodes.insert(node.clone(), value_ptr);
resolve_const(t, val, node.get_name()?)?
}
op => {
let mut deps = vec![];
let mut is_const_node = op.is_const_optimizable()?;
for dep in node.get_node_dependencies() {
let resolved_dep = node_mapping.get(&dep);
match resolved_dep {
Some(resolved_dep_node) => deps.push(resolved_dep_node.clone()),
None => {
panic!("Logic error: unprocessed node in dependencies");
}
};
if !constant_nodes.contains_key(&dep) {
is_const_node = false;
}
}
if is_const_node && node.get_annotations()?.is_empty() {
let dep_vals: Vec<Value> = node
.get_node_dependencies()
.iter()
.map(|dep| constant_nodes.get(dep).unwrap().clone())
.collect();
let value_ptr = evaluator.evaluate_node(node.clone(), dep_vals)?;
constant_nodes.insert(node.clone(), value_ptr.clone());
resolve_const(node.get_type()?, value_ptr.clone(), node.get_name()?)?
} else {
let result = out_graph.add_node(deps, vec![], node.get_operation())?;
for annotation in node.get_annotations()? {
result.add_annotation(annotation)?;
}
copy_node_name(node.clone(), result.clone())?;
result
}
}
};
if node == graph.get_output_node()? {
new_node.set_as_output()?;
}
node_mapping.insert(node, new_node);
}
Ok(())
}
// Unit tests for `optimize_graph_constants`.
#[cfg(test)]
mod tests {
use super::*;
use crate::data_types::{array_type, scalar_type, UINT64};
use crate::evaluators::simple_evaluator::SimpleEvaluator;
use crate::graphs::create_context;
use crate::graphs::util::simple_context;
use crate::graphs::{contexts_deep_equal, Context};
// Runs the constant optimization on the main graph of `c` and returns a new,
// finalized context whose main graph is the optimized graph.
fn optimize_context(c: &Context) -> Result<Context> {
let mut evaluator = SimpleEvaluator::new(None)?;
evaluator.preprocess(c.clone())?;
let new_c = create_context()?;
let new_g = new_c.create_graph()?;
optimize_graph_constants(c.get_main_graph()?.clone(), new_g.clone(), &mut evaluator)?;
// The optimizer does not finalize its output graph; do it here.
new_g.finalize()?;
new_g.set_as_main()?;
new_c.finalize()?;
Ok(new_c)
}
// A graph with a single constant and nothing to fold must come out
// deep-equal to the input.
#[test]
fn test_no_duplicates() {
|| -> Result<()> {
let c = simple_context(|g| {
let i1 = g.input(scalar_type(UINT64))?;
let i2 = g.input(scalar_type(UINT64))?;
let n = i1.add(i2)?;
n.add(g.constant(scalar_type(UINT64), Value::from_scalar(1, UINT64)?)?)
})?;
assert!(contexts_deep_equal(optimize_context(&c)?, c));
Ok(())
}()
.unwrap();
}
// A graph containing `random`, `random_permutation` and nodes derived from
// them must come out deep-equal to the input.
#[test]
fn test_random_is_not_removed() {
|| -> Result<()> {
let c = simple_context(|g| {
let i1 = g.input(scalar_type(UINT64))?;
let i2 = g.input(scalar_type(UINT64))?;
let n = i1.add(i2)?;
let r1 = g.random(scalar_type(UINT64))?;
let r2 = g.random_permutation(5)?;
let r3 = r2.cuckoo_to_permutation()?;
let r4 = r2.decompose_switching_map(5)?;
let o1 = n.add(g.constant(scalar_type(UINT64), Value::from_scalar(1, UINT64)?)?)?;
g.create_tuple(vec![o1.add(r1)?, r3, r4])
})?;
assert!(contexts_deep_equal(optimize_context(&c)?, c));
Ok(())
}()
.unwrap();
}
// A graph consisting of a single `zeros` node of a 1000x1000 array type must
// come out deep-equal to the input.
#[test]
fn test_zeros() {
|| -> Result<()> {
let c = simple_context(|g| g.zeros(array_type(vec![1000, 1000], UINT64)))?;
assert!(contexts_deep_equal(optimize_context(&c)?, c));
Ok(())
}()
.unwrap();
}
// Same as `test_zeros`, but for a single `ones` node.
#[test]
fn test_ones() {
|| -> Result<()> {
let c = simple_context(|g| g.ones(array_type(vec![1000, 1000], UINT64)))?;
assert!(contexts_deep_equal(optimize_context(&c)?, c));
Ok(())
}()
.unwrap();
}
// Three constants equal to 1 and two constants equal to 2 are chained through
// additions; after optimization each distinct value must map to one node.
#[test]
fn test_constants_simple_deduplication() {
|| -> Result<()> {
let c = simple_context(|g| {
let i = g.input(scalar_type(UINT64))?;
let const1 = g.constant(scalar_type(UINT64), Value::from_scalar(1, UINT64)?)?;
const1.set_name("First constant 1")?;
let n1 = i.add(const1)?;
let const2 = g.constant(scalar_type(UINT64), Value::from_scalar(1, UINT64)?)?;
const2.set_name("Second constant 1")?;
let n2 = n1.add(const2)?;
let n3 =
n2.add(g.constant(scalar_type(UINT64), Value::from_scalar(1, UINT64)?)?)?;
let n4 =
n3.add(g.constant(scalar_type(UINT64), Value::from_scalar(2, UINT64)?)?)?;
n4.add(g.constant(scalar_type(UINT64), Value::from_scalar(2, UINT64)?)?)
})?;
let new_c = optimize_context(&c)?;
assert!(!contexts_deep_equal(new_c.clone(), c));
// Walk the addition chain backwards from the output: dependency 0 is the
// previous sum, dependency 1 is the constant operand.
let new_o = new_c.get_main_graph()?.get_output_node()?;
let two1 = new_o.get_node_dependencies()[1].clone();
let new_n4 = new_o.get_node_dependencies()[0].clone();
let two2 = new_n4.get_node_dependencies()[1].clone();
let new_n3 = new_n4.get_node_dependencies()[0].clone();
let one1 = new_n3.get_node_dependencies()[1].clone();
let new_n2 = new_n3.get_node_dependencies()[0].clone();
let one2 = new_n2.get_node_dependencies()[1].clone();
let new_n1 = new_n2.get_node_dependencies()[0].clone();
let one3 = new_n1.get_node_dependencies()[1].clone();
assert!(one1 == one2);
assert!(one1 == one3);
assert!(one1 != two1);
assert!(two1 == two2);
// The name of the first constant 1 is kept on the shared node...
let new_const1 = new_c.retrieve_node(new_c.get_main_graph()?, "First constant 1");
assert!(new_const1.is_ok());
assert_eq!(
new_const1?.get_operation(),
Operation::Constant(scalar_type(UINT64), Value::from_scalar(1, UINT64)?)
);
// ...while the name of the second, deduplicated constant is gone.
let new_const2 = new_c.retrieve_node(new_c.get_main_graph()?, "Second constant 1");
assert!(new_const2.is_err());
Ok(())
}()
.unwrap();
}
// Checks that `2 + 2` is folded into a constant 4 and shares a node with a
// literal constant 4, in both creation orders.
#[test]
fn test_constants_simple_arithmetic() {
// Case 1: the literal 4 is created before the `2 + 2` sum.
|| -> Result<()> {
let c = simple_context(|g| {
let i = g.input(scalar_type(UINT64))?;
let const1 = g.constant(scalar_type(UINT64), Value::from_scalar(4, UINT64)?)?;
const1.set_name("First constant")?;
let n1 = i.add(const1)?;
let const2 = g.constant(scalar_type(UINT64), Value::from_scalar(2, UINT64)?)?;
let const3 = g.constant(scalar_type(UINT64), Value::from_scalar(2, UINT64)?)?;
let const4 = const2.add(const3)?;
const4.set_name("Fourth constant")?;
n1.add(const4)
})?;
let new_c = optimize_context(&c)?;
assert!(!contexts_deep_equal(new_c.clone(), c));
let new_o = new_c.get_main_graph()?.get_output_node()?;
let four1 = new_o.get_node_dependencies()[1].clone();
let new_n1 = new_o.get_node_dependencies()[0].clone();
let four2 = new_n1.get_node_dependencies()[1].clone();
assert!(four1 == four2);
let new_const1 = new_c.retrieve_node(new_c.get_main_graph()?, "First constant");
assert!(new_const1.is_ok());
assert_eq!(
new_const1?.get_operation(),
Operation::Constant(scalar_type(UINT64), Value::from_scalar(4, UINT64)?)
);
// The name given to the folded sum is not retrievable.
let new_const4 = new_c.retrieve_node(new_c.get_main_graph()?, "Fourth constant");
assert!(new_const4.is_err());
Ok(())
}()
.unwrap();
// Case 2: the `2 + 2` sum is created before the literal 4.
|| -> Result<()> {
let c = simple_context(|g| {
let i = g.input(scalar_type(UINT64))?;
let const1 = g.constant(scalar_type(UINT64), Value::from_scalar(2, UINT64)?)?;
let const2 = g.constant(scalar_type(UINT64), Value::from_scalar(2, UINT64)?)?;
let const3 = const1.add(const2)?;
const3.set_name("First constant")?;
let n = i.add(const3)?;
let const4 = g.constant(scalar_type(UINT64), Value::from_scalar(4, UINT64)?)?;
const4.set_name("Fourth constant")?;
n.add(const4)
})?;
let new_c = optimize_context(&c)?;
assert!(!contexts_deep_equal(new_c.clone(), c));
let new_o = new_c.get_main_graph()?.get_output_node()?;
let four1 = new_o.get_node_dependencies()[1].clone();
let new_n = new_o.get_node_dependencies()[0].clone();
let four2 = new_n.get_node_dependencies()[1].clone();
assert!(four1 == four2);
// Here the folded sum carries the name "First constant" and is a constant 4.
let new_const1 = new_c.retrieve_node(new_c.get_main_graph()?, "First constant");
assert!(new_const1.is_ok());
assert_eq!(
new_const1?.get_operation(),
Operation::Constant(scalar_type(UINT64), Value::from_scalar(4, UINT64)?)
);
// The later literal 4 is deduplicated, so its name is not retrievable.
let new_const4 = new_c.retrieve_node(new_c.get_main_graph()?, "Fourth constant");
assert!(new_const4.is_err());
Ok(())
}()
.unwrap();
}
}