use std::collections::hash_map::DefaultHasher; use std::convert::Infallible;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use hugr_core::Node;
use hugr_core::core::HugrNode;
use hugr_core::ops::Value;
use hugr_core::ops::constant::OpaqueValue;
use hugr_core::types::ConstTypeError;
use crate::passes::dataflow::{AbstractValue, AsConcrete, ConstLocation, LoadedFunction, Sum};
/// An [`OpaqueValue`] paired with a hash that was computed once, up front.
///
/// Built by [`HashedConst::try_new`], which only succeeds for constants whose
/// `try_hash` reports that hashing is supported.
#[derive(Debug, Clone)]
pub struct HashedConst {
    /// Cached result of feeding `val` through a [`DefaultHasher`].
    hash: u64,
    /// The wrapped constant, shared behind an [`Arc`].
    pub(super) val: Arc<OpaqueValue>,
}
impl HashedConst {
    /// Wraps `val` together with its precomputed hash.
    ///
    /// Returns `None` when the underlying constant's `try_hash` returns
    /// `false`, i.e. the constant cannot be hashed.
    pub(super) fn try_new(val: Arc<OpaqueValue>) -> Option<Self> {
        let mut state = DefaultHasher::new();
        let hashable = val.value().try_hash(&mut state);
        if !hashable {
            return None;
        }
        let hash = state.finish();
        Some(Self { hash, val })
    }
}
impl PartialEq for HashedConst {
    /// Equal when the cached hashes match *and* the wrapped constants agree
    /// under `equal_consts`; the cheap hash check runs first.
    fn eq(&self, other: &Self) -> bool {
        if self.hash != other.hash {
            return false;
        }
        let (lhs, rhs) = (self.val.value(), other.val.value());
        lhs.equal_consts(rhs)
    }
}

impl Eq for HashedConst {}

impl Hash for HashedConst {
    /// Feeds only the cached hash into `hasher`, so values equal under
    /// `PartialEq` (which requires equal cached hashes) hash identically.
    fn hash<S: Hasher>(&self, hasher: &mut S) {
        let cached = self.hash;
        hasher.write_u64(cached);
    }
}
/// Abstract value used by the dataflow analysis to stand for a constant leaf.
///
/// Constants that can be hashed are identified by their value; all others are
/// identified by the location they were loaded from.
#[derive(Debug, Clone)]
pub enum ValueHandle<N = Node> {
    /// A constant whose hash could be computed; compared by value.
    Hashable(HashedConst),
    /// A constant that could not be hashed, keyed by where it was found.
    Unhashable {
        /// The node the constant came from.
        node: N,
        /// `ConstLocation::Field` indices leading from `node` to this constant,
        /// in the order built by `node_and_fields` (index nearest `node` first).
        fields: Vec<usize>,
        /// The constant itself.
        leaf: Arc<OpaqueValue>,
    },
}
/// Flattens a (possibly nested) [`ConstLocation`] into its root node plus the
/// list of field indices.
///
/// The walk starts at the outermost `Field` and descends to the `Node`, so the
/// indices are collected outermost-first and then reversed: the returned vector
/// starts with the index nearest the node and ends with the outermost one.
fn node_and_fields<N: HugrNode>(loc: &ConstLocation<N>) -> (N, Vec<usize>) {
    let mut path = Vec::new();
    let mut cursor = loc;
    let root = loop {
        match cursor {
            ConstLocation::Node(n) => break *n,
            ConstLocation::Field(idx, inner) => {
                path.push(*idx);
                cursor = *inner;
            }
        }
    };
    path.reverse();
    (root, path)
}
impl<N: HugrNode> ValueHandle<N> {
    /// Creates a handle for the opaque constant `val` found at `loc`.
    ///
    /// If the constant can be hashed the result is [`ValueHandle::Hashable`]
    /// (compared by value); otherwise it is [`ValueHandle::Unhashable`],
    /// recording the node and field path derived from `loc`.
    pub fn new_opaque<'a>(loc: impl Into<ConstLocation<'a, N>>, val: OpaqueValue) -> Self
    where
        N: 'a,
    {
        let leaf = Arc::new(val);
        let (node, fields) = node_and_fields(&loc.into());
        match HashedConst::try_new(Arc::clone(&leaf)) {
            Some(hashed) => Self::Hashable(hashed),
            None => Self::Unhashable { node, fields, leaf },
        }
    }
}
impl<N: HugrNode> AbstractValue for ValueHandle<N> {}

impl<N: HugrNode> PartialEq for ValueHandle<N> {
    /// Hashable handles compare by constant value; unhashable handles compare
    /// by location (node, then field path) and ignore the leaf value. A
    /// hashable handle never equals an unhashable one.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Hashable(lhs), Self::Hashable(rhs)) => lhs == rhs,
            (
                Self::Unhashable {
                    node: lhs_node,
                    fields: lhs_fields,
                    ..
                },
                Self::Unhashable {
                    node: rhs_node,
                    fields: rhs_fields,
                    ..
                },
            ) => lhs_node == rhs_node && lhs_fields == rhs_fields,
            (Self::Hashable(_), Self::Unhashable { .. })
            | (Self::Unhashable { .. }, Self::Hashable(_)) => false,
        }
    }
}

impl<N: HugrNode> Eq for ValueHandle<N> {}
impl<N: HugrNode> Hash for ValueHandle<N> {
    /// Hashes exactly the data that `PartialEq` compares: the cached constant
    /// hash for hashable handles, or the node followed by the field path for
    /// unhashable ones (the leaf value is deliberately left out).
    fn hash<I: Hasher>(&self, state: &mut I) {
        match self {
            Self::Hashable(hashed) => Hash::hash(hashed, state),
            Self::Unhashable { node, fields, .. } => (node, fields).hash(state),
        }
    }
}
/// Converts dataflow-analysis results back into concrete [`Value`]s.
impl<N: HugrNode> AsConcrete<ValueHandle<N>, N> for Value {
    type ValErr = Infallible;
    type SumErr = ConstTypeError;

    /// Turns a [`ValueHandle`] back into a [`Value::Extension`]; never fails.
    ///
    /// Both variants carry the leaf [`OpaqueValue`]. It is moved out of its
    /// `Arc` when this handle holds the only reference, and cloned otherwise.
    fn from_value(value: ValueHandle<N>) -> Result<Self, Self::ValErr> {
        let leaf = match value {
            ValueHandle::Hashable(HashedConst { val, .. }) => val,
            ValueHandle::Unhashable { leaf, .. } => leaf,
        };
        Ok(Value::Extension {
            e: Arc::try_unwrap(leaf).unwrap_or_else(|shared| shared.as_ref().clone()),
        })
    }

    /// Builds a sum constant from the tag, field values and sum type in
    /// `value`, propagating any [`ConstTypeError`] returned by `Value::sum`.
    fn from_sum(value: Sum<Self>) -> Result<Self, Self::SumErr> {
        Self::sum(value.tag, value.values, value.st)
    }

    /// Loaded functions are not converted into a [`Value`] here; `func` is
    /// always handed back unchanged as the error.
    fn from_func(func: LoadedFunction<N>) -> Result<Self, LoadedFunction<N>> {
        Err(func)
    }
}
#[cfg(test)]
mod test {
    use hugr_core::{
        extension::prelude::ConstString,
        std_extensions::{
            arithmetic::{
                float_types::{ConstF64, float64_type},
                int_types::{ConstInt, INT_TYPES},
            },
            collections::list::ListValue,
        },
    };

    use super::*;

    /// Hashable constants compare by value; unhashable ones (floats) by location.
    #[test]
    fn value_key_eq() {
        let n = Node::from(portgraph::NodeIndex::new(0));
        let n2: Node = portgraph::NodeIndex::new(1).into();
        let h1 = ValueHandle::new_opaque(n, ConstString::new("foo".to_string()).into());
        let h2 = ValueHandle::new_opaque(n2, ConstString::new("foo".to_string()).into());
        let h3 = ValueHandle::new_opaque(n, ConstString::new("bar".to_string()).into());
        assert_eq!(h1, h2);
        assert_ne!(h1, h3);

        let f = ConstF64::new(std::f64::consts::PI);
        let h4 = ValueHandle::new_opaque(n, f.clone().into());
        assert_ne!(h4, h1);
        assert_ne!(h1, h4);

        let f2 = ConstF64::new(std::f64::consts::E);
        assert_eq!(h4, ValueHandle::new_opaque(n, f2.clone().into()));
        assert_ne!(
            h4,
            ValueHandle::new_opaque(ConstLocation::Field(5, &n.into()), f2.into())
        );
    }

    /// Lists inherit hashability from their elements.
    #[test]
    fn value_key_list() {
        let v1 = ConstInt::new_u(3, 3).unwrap();
        let v2 = ConstInt::new_u(4, 3).unwrap();
        let v3 = ConstF64::new(std::f64::consts::PI);
        let n = Node::from(portgraph::NodeIndex::new(0));

        let lst = ListValue::new(INT_TYPES[0].clone(), [v1.into(), v2.into()]);
        assert_eq!(
            ValueHandle::new_opaque(n, lst.clone().into()),
            ValueHandle::new_opaque(ConstLocation::Field(1, &n.into()), lst.into())
        );

        let lst = ListValue::new(float64_type(), [v3.into()]);
        assert_ne!(
            ValueHandle::new_opaque(n, lst.clone().into()),
            ValueHandle::new_opaque(ConstLocation::Field(3, &n.into()), lst.into())
        );
    }

    /// `Hash` must agree with `PartialEq`: handles that compare equal must hash
    /// equally, for both the by-value and the by-location variants.
    #[test]
    fn value_key_hash_consistent() {
        fn hash_of<T: Hash>(t: &T) -> u64 {
            let mut hasher = DefaultHasher::new();
            t.hash(&mut hasher);
            hasher.finish()
        }

        let n = Node::from(portgraph::NodeIndex::new(0));
        let n2: Node = portgraph::NodeIndex::new(1).into();

        // Hashable: equal by value even at different nodes.
        let h1 = ValueHandle::new_opaque(n, ConstString::new("foo".to_string()).into());
        let h2 = ValueHandle::new_opaque(n2, ConstString::new("foo".to_string()).into());
        assert_eq!(h1, h2);
        assert_eq!(hash_of(&h1), hash_of(&h2));

        // Unhashable: equal by location even with different values.
        let f1 = ValueHandle::new_opaque(n, ConstF64::new(std::f64::consts::PI).into());
        let f2 = ValueHandle::new_opaque(n, ConstF64::new(std::f64::consts::E).into());
        assert_eq!(f1, f2);
        assert_eq!(hash_of(&f1), hash_of(&f2));
    }
}