mod custom;
use std::borrow::Cow;
use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher};
use super::{NamedOp, OpName, OpTrait, StaticTag};
use super::{OpTag, OpType};
use crate::extension::ExtensionSet;
use crate::types::{CustomType, EdgeKind, Signature, SumType, SumTypeError, Type, TypeRow};
use crate::{Hugr, HugrView};
use delegate::delegate;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
use thiserror::Error;
pub use custom::{
downcast_equal_consts, get_pair_of_input_values, get_single_input_value, CustomConst,
CustomSerialized, TryHash,
};
/// A constant operation, wrapping a single [`Value`].
///
/// The operation's static output carries the wrapped value's type.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Const {
    /// The constant value (serialized under the key `"v"`).
    #[serde(rename = "v")]
    pub value: Value,
}
impl Const {
pub fn new(value: Value) -> Self {
Self { value }
}
pub fn value(&self) -> &Value {
&self.value
}
delegate! {
to self.value {
pub fn get_type(&self) -> Type;
pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T>;
pub fn validate(&self) -> Result<(), ConstTypeError>;
}
}
}
impl From<Value> for Const {
fn from(value: Value) -> Self {
Self::new(value)
}
}
impl NamedOp for Const {
    /// The operation is named after the value it wraps.
    fn name(&self) -> OpName {
        let Self { value } = self;
        value.name()
    }
}
impl StaticTag for Const {
    /// Constants are always tagged [`OpTag::Const`].
    const TAG: OpTag = OpTag::Const;
}
impl OpTrait for Const {
    fn description(&self) -> &str {
        "Constant value"
    }

    /// The extensions required by the wrapped value.
    fn extension_delta(&self) -> ExtensionSet {
        self.value.extension_reqs()
    }

    fn tag(&self) -> OpTag {
        <Const as StaticTag>::TAG
    }

    /// The static output is a constant edge typed by the wrapped value.
    fn static_output(&self) -> Option<EdgeKind> {
        let value_type = self.value.get_type();
        Some(EdgeKind::Const(value_type))
    }
}
impl From<Const> for Value {
fn from(konst: Const) -> Self {
konst.value
}
}
impl AsRef<Value> for Const {
    /// Borrows the wrapped [`Value`].
    fn as_ref(&self) -> &Value {
        &self.value
    }
}
/// Serialized representation of a [`Sum`].
///
/// `tag` defaults to `0` when absent and `sum_type` may be omitted; converting
/// back into a [`Sum`] then infers a tuple type (which only accepts tag `0`).
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct SerialSum {
    /// Selected variant; `0` if missing from the input.
    #[serde(default)]
    tag: usize,
    /// The values in the selected variant (key `"vs"`).
    #[serde(rename = "vs")]
    values: Vec<Value>,
    /// The full sum type (key `"typ"`), optional in the serialized form.
    #[serde(default, rename = "typ")]
    sum_type: Option<SumType>,
}
/// A sum value: a `tag` selecting one variant of `sum_type`, together with the
/// `values` that populate that variant.
///
/// (De)serialized through [`SerialSum`]. Deserialization does not itself check
/// that `values` match `sum_type`; use [`Value::validate`] for that.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(try_from = "SerialSum")]
#[serde(into = "SerialSum")]
pub struct Sum {
    /// Index of the selected variant.
    pub tag: usize,
    /// The values in the selected variant.
    pub values: Vec<Value>,
    /// The type of the whole sum.
    pub sum_type: SumType,
}
impl Sum {
    /// Returns the element values if the sum type is a tuple, `None` otherwise.
    pub fn as_tuple(&self) -> Option<&[Value]> {
        self.sum_type.as_tuple()?;
        Some(self.values.as_slice())
    }

    /// Hashes this sum into `st` if every contained value is hashable.
    ///
    /// Returns `false` (after writing nothing to `st`) when some value cannot
    /// be hashed.
    fn try_hash<H: Hasher>(&self, st: &mut H) -> bool {
        if !maybe_hash_values(&self.values, st) {
            return false;
        }
        st.write_usize(self.tag);
        self.sum_type.hash(st);
        true
    }
}
/// Hashes a slice of values into `st`, all-or-nothing.
///
/// The values are first hashed into a scratch [`DefaultHasher`]; only if every
/// value is hashable is the scratch digest written to `st`. Otherwise `st` is
/// left untouched and `false` is returned.
pub(crate) fn maybe_hash_values<H: Hasher>(vals: &[Value], st: &mut H) -> bool {
    let mut scratch = DefaultHasher::new();
    for val in vals {
        if !val.try_hash(&mut scratch) {
            return false;
        }
    }
    st.write_u64(scratch.finish());
    true
}
impl TryFrom<SerialSum> for Sum {
type Error = &'static str;
fn try_from(value: SerialSum) -> Result<Self, Self::Error> {
let SerialSum {
tag,
values,
sum_type,
} = value;
let sum_type = if let Some(sum_type) = sum_type {
sum_type
} else {
if tag != 0 {
return Err("Sum type must be provided if tag is not 0");
}
SumType::new_tuple(values.iter().map(Value::get_type).collect_vec())
};
Ok(Self {
tag,
values,
sum_type,
})
}
}
impl From<Sum> for SerialSum {
fn from(value: Sum) -> Self {
Self {
tag: value.tag,
values: value.values,
sum_type: Some(value.sum_type),
}
}
}
/// A constant value: an opaque extension value, a function given by a Hugr,
/// or a sum (tuples are represented as sums, see [`Value::tuple`]).
///
/// Serialized as an internally tagged enum, with the variant name under `"v"`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "v")]
pub enum Value {
    /// An extension constant, held behind the [`CustomConst`] trait.
    Extension {
        // Flattened so the payload's fields sit next to the "v" tag.
        #[serde(flatten)]
        e: OpaqueValue,
    },
    /// A function constant defined by a Hugr.
    Function {
        /// The Hugr defining the function.
        hugr: Box<Hugr>,
    },
    /// A sum value.
    // Deserialization also accepts the tag "Tuple" for this variant.
    #[serde(alias = "Tuple")]
    Sum(Sum),
}
/// Boxed [`CustomConst`] trait object.
///
/// Use [`Value::extension`] to create a [`Value::Extension`] holding one of
/// these.
///
/// Serialization is delegated to `custom::serde_extension_value` and flattened
/// into the enclosing [`Value::Extension`] variant.
// The previous `cfg_attr(..., doc = "```")` attributes were left over from a
// removed doctest and rendered an unterminated code fence as the entire doc.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpaqueValue {
    #[serde(flatten, with = "self::custom::serde_extension_value")]
    v: Box<dyn CustomConst>,
}
impl OpaqueValue {
    /// Boxes a custom constant into an [`OpaqueValue`].
    pub fn new(cc: impl CustomConst) -> Self {
        let boxed: Box<dyn CustomConst> = Box::new(cc);
        Self { v: boxed }
    }

    /// Borrows the boxed custom constant.
    pub fn value(&self) -> &dyn CustomConst {
        &*self.v
    }

    /// Mutably borrows the boxed custom constant.
    pub(crate) fn value_mut(&mut self) -> &mut dyn CustomConst {
        &mut *self.v
    }

    delegate! {
        to self.value() {
            /// Returns the type of the custom constant.
            pub fn get_type(&self) -> Type;
            /// Returns the name of the custom constant.
            pub fn name(&self) -> ValueName;
            /// Returns the extensions required by the custom constant.
            pub fn extension_reqs(&self) -> ExtensionSet;
        }
    }
}
impl<CC: CustomConst> From<CC> for OpaqueValue {
fn from(x: CC) -> Self {
Self::new(x)
}
}
impl From<Box<dyn CustomConst>> for OpaqueValue {
fn from(value: Box<dyn CustomConst>) -> Self {
Self { v: value }
}
}
impl PartialEq for OpaqueValue {
    /// Two opaque values are equal when their custom constants report
    /// `equal_consts`.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.value(), other.value());
        lhs.equal_consts(rhs)
    }
}
/// A failure reported while checking a custom (extension) type or value.
#[derive(Clone, Debug, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum CustomCheckFailure {
    /// The value's type differs from the expected custom type.
    #[error("Expected type: {expected} but value was of type: {found}")]
    TypeMismatch {
        /// The custom type that was expected.
        expected: CustomType,
        /// The type the value actually has.
        found: Type,
    },
    /// Any other failure, described by a free-form message.
    #[error("{0}")]
    Message(String),
}
/// Errors raised when a constant value does not type-check.
#[derive(Clone, Debug, PartialEq, Error)]
#[non_exhaustive]
pub enum ConstTypeError {
    /// A sum value failed to check against its [`SumType`].
    #[error("{0}")]
    SumType(#[from] SumTypeError),
    /// The Hugr of a function constant has no monomorphic function type.
    #[error(
        "A function constant cannot be defined using a Hugr with root of type {hugr_root_type}. Must be a monomorphic function.",
    )]
    NotMonomorphicFunction {
        /// The optype of the offending Hugr's root node.
        hugr_root_type: OpType,
    },
    /// A value (second field) does not match the expected type (first field).
    #[error("Value {1:?} does not match expected type {0}")]
    ConstCheckFail(Type, Value),
    /// Checking a custom (extension) constant failed.
    #[error("Error when checking custom type: {0}")]
    CustomCheckFail(#[from] CustomCheckFailure),
}
/// Returns the monomorphic function signature of `h`, for use as a function
/// constant.
///
/// If `h` reports a polymorphic function type, it must convert into a plain
/// [`Signature`]. Otherwise the Hugr's inner function type is used.
///
/// # Errors
///
/// [`ConstTypeError::NotMonomorphicFunction`] if no monomorphic signature can
/// be obtained.
fn mono_fn_type(h: &Hugr) -> Result<Cow<'_, Signature>, ConstTypeError> {
    let not_mono = || ConstTypeError::NotMonomorphicFunction {
        hugr_root_type: h.root_type().clone(),
    };
    match h.poly_func_type() {
        Some(poly) => {
            let sig: Signature = poly.try_into().map_err(|_| not_mono())?;
            Ok(Cow::Owned(sig))
        }
        None => h.inner_function_type().ok_or_else(not_mono),
    }
}
impl Value {
pub fn get_type(&self) -> Type {
match self {
Self::Extension { e } => e.get_type(),
Self::Sum(Sum { sum_type, .. }) => sum_type.clone().into(),
Self::Function { hugr } => {
let func_type = mono_fn_type(hugr).unwrap_or_else(|e| panic!("{}", e));
Type::new_function(func_type.into_owned())
}
}
}
pub fn sum(
tag: usize,
items: impl IntoIterator<Item = Value>,
typ: SumType,
) -> Result<Self, ConstTypeError> {
let values: Vec<Value> = items.into_iter().collect();
typ.check_type(tag, &values)?;
Ok(Self::Sum(Sum {
tag,
values,
sum_type: typ,
}))
}
pub fn tuple(items: impl IntoIterator<Item = Value>) -> Self {
let vs = items.into_iter().collect_vec();
let tys = vs.iter().map(Self::get_type).collect_vec();
Self::sum(0, vs, SumType::new_tuple(tys)).expect("Tuple type is valid")
}
pub fn function(hugr: impl Into<Hugr>) -> Result<Self, ConstTypeError> {
let hugr = hugr.into();
mono_fn_type(&hugr)?;
Ok(Self::Function {
hugr: Box::new(hugr),
})
}
pub const fn unit() -> Self {
Self::Sum(Sum {
tag: 0,
values: vec![],
sum_type: SumType::Unit { size: 1 },
})
}
pub fn unit_sum(tag: usize, size: u8) -> Result<Self, ConstTypeError> {
Self::sum(tag, [], SumType::Unit { size })
}
pub fn unary_unit_sum() -> Self {
Self::unit_sum(0, 1).expect("0 < 1")
}
pub fn true_val() -> Self {
Self::unit_sum(1, 2).expect("1 < 2")
}
pub fn false_val() -> Self {
Self::unit_sum(0, 2).expect("0 < 2")
}
pub fn some<V: Into<Value>>(values: impl IntoIterator<Item = V>) -> Self {
let values: Vec<Value> = values.into_iter().map(Into::into).collect_vec();
let value_types: Vec<Type> = values.iter().map(|v| v.get_type()).collect_vec();
let sum_type = SumType::new_option(value_types);
Self::sum(1, values, sum_type).unwrap()
}
pub fn none(value_types: impl Into<TypeRow>) -> Self {
Self::sum(0, [], SumType::new_option(value_types)).unwrap()
}
pub fn from_bool(b: bool) -> Self {
if b {
Self::true_val()
} else {
Self::false_val()
}
}
pub fn extension(custom_const: impl CustomConst) -> Self {
Self::Extension {
e: OpaqueValue::new(custom_const),
}
}
pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T> {
if let Self::Extension { e } = self {
e.v.downcast_ref()
} else {
None
}
}
fn name(&self) -> OpName {
match self {
Self::Extension { e } => format!("const:custom:{}", e.name()),
Self::Function { hugr: h } => {
let Ok(t) = mono_fn_type(h) else {
panic!("HUGR root node isn't a valid function parent.");
};
format!("const:function:[{}]", t)
}
Self::Sum(Sum {
tag,
values,
sum_type,
}) => {
if sum_type.as_tuple().is_some() {
let names: Vec<_> = values.iter().map(Value::name).collect();
format!("const:seq:{{{}}}", names.iter().join(", "))
} else {
format!("const:sum:{{tag:{tag}, vals:{values:?}}}")
}
}
}
.into()
}
pub fn extension_reqs(&self) -> ExtensionSet {
match self {
Self::Extension { e } => e.extension_reqs().clone(),
Self::Function { .. } => ExtensionSet::new(), Self::Sum(Sum { values, .. }) => {
ExtensionSet::union_over(values.iter().map(|x| x.extension_reqs()))
}
}
}
pub fn validate(&self) -> Result<(), ConstTypeError> {
match self {
Self::Extension { e } => Ok(e.value().validate()?),
Self::Function { hugr } => {
mono_fn_type(hugr)?;
Ok(())
}
Self::Sum(Sum {
tag,
values,
sum_type,
}) => {
sum_type.check_type(*tag, values)?;
Ok(())
}
}
}
pub fn as_tuple(&self) -> Option<&[Value]> {
if let Self::Sum(sum) = self {
sum.as_tuple()
} else {
None
}
}
pub fn try_hash<H: Hasher>(&self, st: &mut H) -> bool {
match self {
Value::Extension { e } => e.value().try_hash(&mut *st),
Value::Function { .. } => false,
Value::Sum(s) => s.try_hash(st),
}
}
}
impl<T> From<T> for Value
where
T: CustomConst,
{
fn from(value: T) -> Self {
Self::extension(value)
}
}
/// The name of a constant value.
pub type ValueName = SmolStr;
/// Borrowed form of [`ValueName`].
pub type ValueNameRef = str;
/// Unit tests for constant values, plus `proptest` strategies for
/// `OpaqueValue` and `Value`.
#[cfg(test)]
pub(crate) mod test {
    use std::collections::HashSet;
    use std::sync::{Arc, Weak};
    use super::Value;
    use crate::builder::inout_sig;
    use crate::builder::test::simple_dfg_hugr;
    use crate::extension::prelude::{bool_t, usize_custom_t};
    use crate::extension::resolution::{
        resolve_custom_type_extensions, resolve_typearg_extensions, ExtensionResolutionError,
        WeakExtensionRegistry,
    };
    use crate::extension::PRELUDE;
    use crate::std_extensions::arithmetic::int_types::ConstInt;
    use crate::std_extensions::collections::array::{array_type, ArrayValue};
    use crate::{
        builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr},
        extension::{
            prelude::{usize_t, ConstUsize},
            ExtensionId,
        },
        std_extensions::arithmetic::float_types::{float64_type, ConstF64},
        type_row,
        types::type_param::TypeArg,
        types::{Type, TypeBound, TypeRow},
    };
    use cool_asserts::assert_matches;
    use rstest::{fixture, rstest};
    use super::*;

    /// Test-only custom constant wrapping a `CustomType`; its type is that
    /// custom type itself.
    #[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
    pub(crate) struct CustomTestValue(pub CustomType);

    #[typetag::serde]
    impl CustomConst for CustomTestValue {
        fn name(&self) -> ValueName {
            format!("CustomTestValue({:?})", self.0).into()
        }

        // Requires only the extension that defines the wrapped type.
        fn extension_reqs(&self) -> ExtensionSet {
            ExtensionSet::singleton(self.0.extension().clone())
        }

        // Resolves the wrapped type, then each of its type arguments, against `extensions`.
        fn update_extensions(
            &mut self,
            extensions: &WeakExtensionRegistry,
        ) -> Result<(), ExtensionResolutionError> {
            resolve_custom_type_extensions(&mut self.0, extensions)?;
            for arg in self.0.args_mut() {
                resolve_typearg_extensions(arg, extensions)?;
            }
            Ok(())
        }

        fn get_type(&self) -> Type {
            self.0.clone().into()
        }

        fn equal_consts(&self, other: &dyn CustomConst) -> bool {
            crate::ops::constant::downcast_equal_consts(self, other)
        }
    }

    /// A float constant converted into its `CustomSerialized` (opaque) form.
    pub(crate) fn serialized_float(f: f64) -> Value {
        CustomSerialized::try_from_custom_const(ConstF64::new(f))
            .unwrap()
            .into()
    }

    /// Builds two DFGs that each output a constant of a two-variant sum type:
    /// first selecting the non-empty variant 0, then the empty variant 1.
    #[test]
    fn test_sum() -> Result<(), BuildError> {
        use crate::builder::Container;
        let pred_rows = vec![vec![usize_t(), float64_type()].into(), Type::EMPTY_TYPEROW];
        let pred_ty = SumType::new(pred_rows.clone());
        let mut b = DFGBuilder::new(inout_sig(
            type_row![],
            TypeRow::from(vec![pred_ty.clone().into()]),
        ))?;
        // NOTE(review): presumably the usize custom type converts to `usize_t()`,
        // so this CustomTestValue fills the first (usize) slot of variant 0.
        let usize_custom_t = usize_custom_t(&Arc::downgrade(&PRELUDE));
        let c = b.add_constant(Value::sum(
            0,
            [
                CustomTestValue(usize_custom_t.clone()).into(),
                ConstF64::new(5.1).into(),
            ],
            pred_ty.clone(),
        )?);
        let w = b.load_const(&c);
        b.finish_hugr_with_outputs([w]).unwrap();
        let mut b = DFGBuilder::new(Signature::new(
            type_row![],
            TypeRow::from(vec![pred_ty.clone().into()]),
        ))?;
        let c = b.add_constant(Value::sum(1, [], pred_ty.clone())?);
        let w = b.load_const(&c);
        b.finish_hugr_with_outputs([w]).unwrap();
        Ok(())
    }

    /// `Value::sum` rejects a wrong variant length, an out-of-range tag, and a
    /// mistyped element.
    #[test]
    fn test_bad_sum() {
        let pred_ty = SumType::new([vec![usize_t(), float64_type()].into(), type_row![]]);
        // Well-typed values, printed as JSON for manual inspection.
        let good_sum = const_usize();
        println!("{}", serde_json::to_string_pretty(&good_sum).unwrap());
        let good_sum =
            Value::sum(0, [const_usize(), serialized_float(5.1)], pred_ty.clone()).unwrap();
        println!("{}", serde_json::to_string_pretty(&good_sum).unwrap());
        // Variant 0 expects two values, none given.
        let res = Value::sum(0, [], pred_ty.clone());
        assert_matches!(
            res,
            Err(ConstTypeError::SumType(SumTypeError::WrongVariantLength {
                tag: 0,
                expected: 2,
                found: 0
            }))
        );
        // Only tags 0 and 1 exist.
        let res = Value::sum(4, [], pred_ty.clone());
        assert_matches!(
            res,
            Err(ConstTypeError::SumType(SumTypeError::InvalidTag {
                tag: 4,
                num_variants: 2
            }))
        );
        // Second element should be a float, not a usize.
        let res = Value::sum(0, [const_usize(), const_usize()], pred_ty);
        assert_matches!(
            res,
            Err(ConstTypeError::SumType(SumTypeError::InvalidValueType {
                tag: 0,
                index: 1,
                expected,
                found,
            })) if expected == float64_type() && found == const_usize()
        );
    }

    /// A function constant built from `simple_dfg_hugr` has a bool-endomorphism
    /// function type and a `const:function:` name.
    #[rstest]
    fn function_value(simple_dfg_hugr: Hugr) {
        let v = Value::function(simple_dfg_hugr).unwrap();
        let correct_type = Type::new_function(Signature::new_endo(vec![bool_t()]));
        assert_eq!(v.get_type(), correct_type);
        assert!(v.name().starts_with("const:function:"))
    }

    /// A concrete `ConstUsize(257)` value.
    #[fixture]
    fn const_usize() -> Value {
        ConstUsize::new(257).into()
    }

    /// `ConstUsize(257)` converted into its `CustomSerialized` form.
    #[fixture]
    fn const_serialized_usize() -> Value {
        CustomSerialized::try_from_custom_const(ConstUsize::new(257))
            .unwrap()
            .into()
    }

    /// A `(usize, bool)` tuple.
    #[fixture]
    fn const_tuple() -> Value {
        Value::tuple([const_usize(), Value::true_val()])
    }

    /// Like `const_tuple`, but with the usize in `CustomSerialized` form.
    #[fixture]
    fn const_tuple_serialized() -> Value {
        Value::tuple([const_serialized_usize(), Value::true_val()])
    }

    /// A two-element array of booleans.
    #[fixture]
    fn const_array_bool() -> Value {
        ArrayValue::new(bool_t(), [Value::true_val(), Value::false_val()]).into()
    }

    /// A two-element array of optional booleans (one "some", one "none").
    #[fixture]
    fn const_array_options() -> Value {
        let some_true = Value::some([Value::true_val()]);
        let none = Value::none(vec![bool_t()]);
        let elem_ty = SumType::new_option(vec![bool_t()]);
        ArrayValue::new(elem_ty.into(), [some_true, none]).into()
    }

    /// Each value has the expected type and a name with the expected prefix.
    #[rstest]
    #[case(Value::unit(), Type::UNIT, "const:seq:{}")]
    #[case(const_usize(), usize_t(), "const:custom:ConstUsize(")]
    #[case(serialized_float(17.4), float64_type(), "const:custom:json:Object")]
    #[case(const_tuple(), Type::new_tuple(vec![usize_t(), bool_t()]), "const:seq:{")]
    #[case(const_array_bool(), array_type(2, bool_t()), "const:custom:array")]
    #[case(
        const_array_options(),
        array_type(2, SumType::new_option(vec![bool_t()]).into()),
        "const:custom:array"
    )]
    fn const_type(
        #[case] const_value: Value,
        #[case] expected_type: Type,
        #[case] name_prefix: &str,
    ) {
        assert_eq!(const_value.get_type(), expected_type);
        let name = const_value.name();
        assert!(
            name.starts_with(name_prefix),
            "{name} does not start with {name_prefix}"
        );
    }

    /// Serializing to JSON and back yields a value equal to the expected one;
    /// `CustomSerialized` inputs are expected to compare equal to their concrete
    /// counterparts after the round trip.
    #[rstest]
    #[case(Value::unit(), Value::unit())]
    #[case(const_usize(), const_usize())]
    #[case(const_serialized_usize(), const_usize())]
    #[case(const_tuple_serialized(), const_tuple())]
    #[case(const_array_bool(), const_array_bool())]
    #[case(const_array_options(), const_array_options())]
    fn const_serde_roundtrip(#[case] const_value: Value, #[case] expected_value: Value) {
        let serialized = serde_json::to_string(&const_value).unwrap();
        let deserialized: Value = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, expected_value);
    }

    /// `get_custom_value` only succeeds for an extension value of the exact type.
    #[rstest]
    fn const_custom_value(const_usize: Value, const_tuple: Value) {
        assert_eq!(
            const_usize.get_custom_value::<ConstUsize>(),
            Some(&ConstUsize::new(257))
        );
        assert_eq!(const_usize.get_custom_value::<ConstInt>(), None);
        assert_eq!(const_tuple.get_custom_value::<ConstUsize>(), None);
        assert_eq!(const_tuple.get_custom_value::<ConstInt>(), None);
    }

    /// A `CustomSerialized` constant reports the custom type it was built with,
    /// which differs from the same-named type with different type arguments.
    #[test]
    fn test_json_const() {
        let ex_id: ExtensionId = "my_extension".try_into().unwrap();
        let typ_int = CustomType::new(
            "my_type",
            vec![TypeArg::BoundedNat { n: 8 }],
            ex_id.clone(),
            TypeBound::Copyable,
            &Weak::default(),
        );
        let json_const: Value =
            CustomSerialized::new(typ_int.clone(), 6.into(), ex_id.clone()).into();
        let classic_t = Type::new_extension(typ_int.clone());
        assert_matches!(classic_t.least_upper_bound(), TypeBound::Copyable);
        assert_eq!(json_const.get_type(), classic_t);
        let typ_qb = CustomType::new(
            "my_type",
            vec![],
            ex_id,
            TypeBound::Copyable,
            &Weak::default(),
        );
        let t = Type::new_extension(typ_qb.clone());
        assert_ne!(json_const.get_type(), t);
    }

    /// All of these values are hashable, and no two share a hash.
    #[rstest]
    fn hash_tuple(const_tuple: Value) {
        let vals = [
            Value::unit(),
            Value::true_val(),
            Value::false_val(),
            ConstUsize::new(13).into(),
            Value::tuple([ConstUsize::new(13).into()]),
            Value::tuple([ConstUsize::new(13).into(), ConstUsize::new(14).into()]),
            Value::tuple([ConstUsize::new(13).into(), ConstUsize::new(15).into()]),
            const_tuple,
        ];
        let num_vals = vals.len();
        let hashes = vals.map(|v| {
            let mut h = DefaultHasher::new();
            v.try_hash(&mut h).then_some(()).unwrap();
            h.finish()
        });
        // Every hash is distinct.
        assert_eq!(HashSet::from(hashes).len(), num_vals);
    }

    /// A tuple containing an unhashable element fails to hash and leaves the
    /// hasher untouched: `h1` then behaves exactly like a fresh hasher `h2`.
    #[test]
    fn unhashable_tuple() {
        let tup = Value::tuple([ConstUsize::new(5).into(), ConstF64::new(4.97).into()]);
        let mut h1 = DefaultHasher::new();
        let r = tup.try_hash(&mut h1);
        assert!(!r);
        h1.write_usize(5);
        let mut h2 = DefaultHasher::new();
        h2.write_usize(5);
        assert_eq!(h1.finish(), h2.finish());
    }

    /// `proptest` strategies for generating constant values.
    mod proptest {
        use super::super::{OpaqueValue, Sum};
        use crate::{
            ops::{constant::CustomSerialized, Value},
            std_extensions::arithmetic::int_types::ConstInt,
            std_extensions::collections::list::ListValue,
            types::{SumType, Type},
        };
        use ::proptest::{collection::vec, prelude::*};

        /// Generates `ConstInt` / `CustomSerialized` leaves, recursively nested
        /// inside `ListValue`s.
        impl Arbitrary for OpaqueValue {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;
            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
                prop_oneof![
                    any::<ConstInt>().prop_map_into(),
                    any::<CustomSerialized>().prop_map_into()
                ]
                .prop_recursive(
                    3,  // depth
                    32, // desired size
                    3,  // expected branch size
                    |child_strat| {
                        (any::<Type>(), vec(child_strat, 0..3)).prop_map(|(typ, children)| {
                            // NOTE(review): `typ` is arbitrary and need not match
                            // the children's types.
                            Self::new(ListValue::new(
                                typ,
                                children.into_iter().map(|e| Value::Extension { e }),
                            ))
                        })
                    },
                )
                .boxed()
            }
        }

        /// Generates extension and function leaves, recursively nested inside
        /// tuples and sums.
        impl Arbitrary for Value {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;
            fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
                use ::proptest::collection::vec;
                let leaf_strat = prop_oneof![
                    any::<OpaqueValue>().prop_map(|e| Self::Extension { e }),
                    crate::proptest::any_hugr().prop_map(|x| Value::function(x).unwrap())
                ];
                leaf_strat
                    .prop_recursive(
                        3,  // depth
                        32, // desired size
                        3,  // expected branch size
                        |element| {
                            prop_oneof![
                                vec(element.clone(), 0..3).prop_map(Self::tuple),
                                (
                                    any::<usize>(),
                                    vec(element.clone(), 0..3),
                                    // NOTE(review): meaning of the `1` parameter is
                                    // defined by SumType's Arbitrary impl; confirm.
                                    any_with::<SumType>(1.into())
                                )
                                    .prop_map(
                                        // Built directly, bypassing `Value::sum`'s checks,
                                        // so tag/values need not match `sum_type`.
                                        |(tag, values, sum_type)| {
                                            Self::Sum(Sum {
                                                tag,
                                                values,
                                                sum_type,
                                            })
                                        }
                                    ),
                            ]
                        },
                    )
                    .boxed()
            }
        }
    }

    /// JSON using the `"Tuple"` variant tag deserializes to the equivalent
    /// tuple built in code.
    #[test]
    fn test_tuple_deserialize() {
        let json = r#"
{
"v": "Tuple",
"vs": [
{
"v": "Sum",
"tag": 0,
"typ": {
"t": "Sum",
"s": "Unit",
"size": 1
},
"vs": []
},
{
"v": "Sum",
"tag": 1,
"typ": {
"t": "Sum",
"s": "General",
"rows": [
[
{
"t": "Sum",
"s": "Unit",
"size": 1
}
],
[
{
"t": "Sum",
"s": "Unit",
"size": 2
}
]
]
},
"vs": [
{
"v": "Sum",
"tag": 1,
"typ": {
"t": "Sum",
"s": "Unit",
"size": 2
},
"vs": []
}
]
}
]
}
"#;
        let v: Value = serde_json::from_str(json).unwrap();
        assert_eq!(
            v,
            Value::tuple([
                Value::unit(),
                Value::sum(
                    1,
                    [Value::true_val()],
                    SumType::new([
                        type_row![Type::UNIT],
                        vec![Value::true_val().get_type()].into()
                    ]),
                )
                .unwrap()
            ])
        );
    }
}