use std::cmp::min;
use std::collections::HashMap;
use std::collections::btree_map::Entry;
use std::fmt::{Debug, Formatter};
use std::sync::{Arc, Weak};
use serde_with::serde_as;
use super::{
ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionSet,
SignatureError,
};
use crate::Hugr;
use crate::envelope::serde_with::AsBinaryEnvelope;
use crate::ops::{OpName, OpNameRef};
use crate::package::Package;
use crate::types::type_param::{TypeArg, TypeParam, check_term_types};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
mod serialize_signature_func;
pub trait CustomSignatureFunc: Send + Sync {
fn compute_signature<'o, 'a: 'o>(
&'a self,
arg_values: &[TypeArg],
def: &'o OpDef,
) -> Result<PolyFuncTypeRV, SignatureError>;
fn static_params(&self) -> &[TypeParam];
}
pub trait SignatureFromArgs: Send + Sync {
fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncTypeRV, SignatureError>;
fn static_params(&self) -> &[TypeParam];
}
impl<T: SignatureFromArgs> CustomSignatureFunc for T {
#[inline]
fn compute_signature<'o, 'a: 'o>(
&'a self,
arg_values: &[TypeArg],
_def: &'o OpDef,
) -> Result<PolyFuncTypeRV, SignatureError> {
SignatureFromArgs::compute_signature(self, arg_values)
}
#[inline]
fn static_params(&self) -> &[TypeParam] {
SignatureFromArgs::static_params(self)
}
}
pub trait ValidateTypeArgs: Send + Sync {
fn validate<'o, 'a: 'o>(
&self,
arg_values: &[TypeArg],
def: &'o OpDef,
) -> Result<(), SignatureError>;
}
pub trait ValidateJustArgs: Send + Sync {
fn validate(&self, arg_values: &[TypeArg]) -> Result<(), SignatureError>;
}
impl<T: ValidateJustArgs> ValidateTypeArgs for T {
#[inline]
fn validate<'o, 'a: 'o>(
&self,
arg_values: &[TypeArg],
_def: &'o OpDef,
) -> Result<(), SignatureError> {
ValidateJustArgs::validate(self, arg_values)
}
}
pub trait CustomLowerFunc: Send + Sync {
fn try_lower(
&self,
name: &OpNameRef,
arg_values: &[TypeArg],
misc: &HashMap<String, serde_json::Value>,
available_extensions: &ExtensionSet,
) -> Option<Hugr>;
}
pub struct CustomValidator {
poly_func: PolyFuncTypeRV,
pub(crate) validate: Box<dyn ValidateTypeArgs>,
}
impl CustomValidator {
pub fn new(
poly_func: impl Into<PolyFuncTypeRV>,
validate: impl ValidateTypeArgs + 'static,
) -> Self {
Self {
poly_func: poly_func.into(),
validate: Box::new(validate),
}
}
pub(crate) fn poly_func(&self) -> &PolyFuncTypeRV {
&self.poly_func
}
pub(super) fn poly_func_mut(&mut self) -> &mut PolyFuncTypeRV {
&mut self.poly_func
}
}
pub enum SignatureFunc {
PolyFuncType(PolyFuncTypeRV),
CustomValidator(CustomValidator),
MissingValidateFunc(PolyFuncTypeRV),
CustomFunc(Box<dyn CustomSignatureFunc>),
MissingComputeFunc,
}
impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc {
fn from(v: T) -> Self {
Self::CustomFunc(Box::new(v))
}
}
impl From<PolyFuncType> for SignatureFunc {
fn from(value: PolyFuncType) -> Self {
Self::PolyFuncType(value.into())
}
}
impl From<PolyFuncTypeRV> for SignatureFunc {
fn from(v: PolyFuncTypeRV) -> Self {
Self::PolyFuncType(v)
}
}
impl From<FuncValueType> for SignatureFunc {
fn from(v: FuncValueType) -> Self {
Self::PolyFuncType(v.into())
}
}
impl From<Signature> for SignatureFunc {
fn from(v: Signature) -> Self {
Self::PolyFuncType(FuncValueType::from(v).into())
}
}
impl From<CustomValidator> for SignatureFunc {
fn from(v: CustomValidator) -> Self {
Self::CustomValidator(v)
}
}
impl SignatureFunc {
fn static_params(&self) -> Result<&[TypeParam], SignatureError> {
Ok(match self {
SignatureFunc::PolyFuncType(ts)
| SignatureFunc::CustomValidator(CustomValidator { poly_func: ts, .. })
| SignatureFunc::MissingValidateFunc(ts) => ts.params(),
SignatureFunc::CustomFunc(func) => func.static_params(),
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
})
}
pub fn ignore_missing_validation(&mut self) {
if let SignatureFunc::MissingValidateFunc(ts) = self {
*self = SignatureFunc::PolyFuncType(ts.clone());
}
}
pub(crate) fn poly_func_type(&self) -> Option<&PolyFuncTypeRV> {
match self {
SignatureFunc::PolyFuncType(ts) | SignatureFunc::MissingValidateFunc(ts) => Some(ts),
SignatureFunc::CustomValidator(custom) => Some(custom.poly_func()),
SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
}
}
pub fn compute_signature(
&self,
def: &OpDef,
args: &[TypeArg],
) -> Result<Signature, SignatureError> {
let temp: PolyFuncTypeRV; let (pf, args) = match &self {
SignatureFunc::CustomValidator(custom) => {
custom.validate.validate(args, def)?;
(&custom.poly_func, args)
}
SignatureFunc::PolyFuncType(ts) => (ts, args),
SignatureFunc::CustomFunc(func) => {
let static_params = func.static_params();
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));
check_term_types(static_args, static_params)?;
temp = func.compute_signature(static_args, def)?;
(&temp, other_args)
}
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
SignatureFunc::MissingValidateFunc(ts) => (ts, args),
};
let res = pf.instantiate(args)?;
res.try_into()
}
}
impl Debug for SignatureFunc {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::CustomValidator(ts) => ts.poly_func.fmt(f),
Self::PolyFuncType(ts) => ts.fmt(f),
Self::CustomFunc { .. } => f.write_str("<custom sig>"),
Self::MissingComputeFunc => f.write_str("<missing custom sig>"),
Self::MissingValidateFunc(_) => f.write_str("<missing custom validation>"),
}
}
}
#[serde_as]
#[derive(serde::Serialize)]
#[serde(untagged)]
pub enum LowerFunc {
FixedHugr {
extensions: ExtensionSet,
#[serde_as(as = "Box<AsBinaryEnvelope>")]
#[serde(rename = "hugr")]
pkg: Box<Package>,
},
#[serde(skip)]
CustomFunc(Box<dyn CustomLowerFunc>),
}
pub fn deserialize_lower_funcs<'de, D>(deserializer: D) -> Result<Vec<LowerFunc>, D::Error>
where
D: serde::Deserializer<'de>,
{
#[serde_as]
#[derive(serde::Deserialize)]
struct FixedHugrDeserializer {
pub extensions: ExtensionSet,
#[serde_as(as = "Box<AsBinaryEnvelope>")]
pub hugr: Box<Package>,
}
let funcs: Vec<FixedHugrDeserializer> = serde::Deserialize::deserialize(deserializer)?;
Ok(funcs
.into_iter()
.map(|f| LowerFunc::FixedHugr {
extensions: f.extensions,
pkg: f.hugr,
})
.collect())
}
impl Debug for LowerFunc {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::FixedHugr { .. } => write!(f, "FixedHugr"),
Self::CustomFunc(_) => write!(f, "<custom lower>"),
}
}
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct OpDef {
extension: ExtensionId,
#[serde(skip)]
extension_ref: Weak<Extension>,
name: OpName,
description: String,
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
misc: HashMap<String, serde_json::Value>,
#[serde(with = "serialize_signature_func", flatten)]
signature_func: SignatureFunc,
#[serde(
default,
skip_serializing_if = "Vec::is_empty",
deserialize_with = "deserialize_lower_funcs"
)]
pub(crate) lower_funcs: Vec<LowerFunc>,
#[serde(skip)]
constant_folder: Option<Box<dyn ConstFold>>,
}
impl OpDef {
pub fn validate_args(
&self,
args: &[TypeArg],
var_decls: &[TypeParam],
) -> Result<(), SignatureError> {
let temp: PolyFuncTypeRV; let (pf, args) = match &self.signature_func {
SignatureFunc::CustomValidator(ts) => (&ts.poly_func, args),
SignatureFunc::PolyFuncType(ts) => (ts, args),
SignatureFunc::CustomFunc(custom) => {
let (static_args, other_args) =
args.split_at(min(custom.static_params().len(), args.len()));
static_args.iter().try_for_each(|ta| ta.validate(&[]))?;
check_term_types(static_args, custom.static_params())?;
temp = custom.compute_signature(static_args, self)?;
(&temp, other_args)
}
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
SignatureFunc::MissingValidateFunc(_) => {
return Err(SignatureError::MissingValidateFunc);
}
};
args.iter().try_for_each(|ta| ta.validate(var_decls))?;
check_term_types(args, pf.params())?;
Ok(())
}
pub fn compute_signature(&self, args: &[TypeArg]) -> Result<Signature, SignatureError> {
self.signature_func.compute_signature(self, args)
}
#[must_use]
pub fn try_lower(&self, args: &[TypeArg], available_extensions: &ExtensionSet) -> Option<Hugr> {
self.lower_funcs
.iter()
.filter_map(|f| match f {
LowerFunc::FixedHugr { extensions, pkg } => {
if available_extensions.is_superset(extensions) {
pkg.modules.first().cloned()
} else {
None
}
}
LowerFunc::CustomFunc(f) => {
f.try_lower(&self.name, args, &self.misc, available_extensions)
}
})
.next()
}
#[must_use]
pub fn name(&self) -> &OpName {
&self.name
}
#[must_use]
pub fn extension_id(&self) -> &ExtensionId {
&self.extension
}
#[must_use]
pub fn extension(&self) -> Weak<Extension> {
self.extension_ref.clone()
}
pub(super) fn extension_mut(&mut self) -> &mut Weak<Extension> {
&mut self.extension_ref
}
#[must_use]
pub fn description(&self) -> &str {
self.description.as_ref()
}
pub fn params(&self) -> Result<&[TypeParam], SignatureError> {
self.signature_func.static_params()
}
pub(super) fn validate(&self) -> Result<(), SignatureError> {
if let SignatureFunc::CustomValidator(ts) = &self.signature_func {
ts.poly_func.validate()?;
}
Ok(())
}
pub fn add_lower_func(&mut self, lower: LowerFunc) {
self.lower_funcs.push(lower);
}
pub fn add_misc(
&mut self,
k: impl ToString,
v: serde_json::Value,
) -> Option<serde_json::Value> {
self.misc.insert(k.to_string(), v)
}
#[allow(unused)] pub(crate) fn iter_misc(&self) -> impl ExactSizeIterator<Item = (&str, &serde_json::Value)> {
self.misc.iter().map(|(k, v)| (k.as_str(), v))
}
pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) {
self.constant_folder = Some(Box::new(fold));
}
#[must_use]
pub fn constant_fold(
&self,
type_args: &[TypeArg],
consts: &[(crate::IncomingPort, crate::ops::Value)],
) -> ConstFoldResult {
(self.constant_folder.as_ref())?.fold(type_args, consts)
}
#[must_use]
pub fn signature_func(&self) -> &SignatureFunc {
&self.signature_func
}
pub(super) fn signature_func_mut(&mut self) -> &mut SignatureFunc {
&mut self.signature_func
}
}
impl Extension {
pub fn add_op(
&mut self,
name: OpName,
description: String,
signature_func: impl Into<SignatureFunc>,
extension_ref: &Weak<Extension>,
) -> Result<&mut OpDef, ExtensionBuildError> {
let op = OpDef {
extension: self.name.clone(),
extension_ref: extension_ref.clone(),
name,
description,
signature_func: signature_func.into(),
misc: Default::default(),
lower_funcs: Default::default(),
constant_folder: Default::default(),
};
match self.operations.entry(op.name.clone()) {
Entry::Occupied(_) => Err(ExtensionBuildError::OpDefExists(op.name)),
Entry::Vacant(ve) => Ok(Arc::get_mut(ve.insert(Arc::new(op))).unwrap()),
}
}
}
#[cfg(test)]
pub(super) mod test {
use std::num::NonZeroU64;
use itertools::Itertools;
use super::SignatureFromArgs;
use crate::builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig};
use crate::extension::SignatureError;
use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc};
use crate::extension::prelude::usize_t;
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
use crate::ops::OpName;
use crate::package::Package;
use crate::std_extensions::collections::list;
use crate::types::type_param::{TermTypeError, TypeParam};
use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV};
use crate::{Extension, const_extension_ids};
const_extension_ids! {
const EXT_ID: ExtensionId = "MyExt";
}
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub struct SimpleOpDef(OpDef);
impl SimpleOpDef {
#[must_use]
pub fn new(op_def: OpDef) -> Self {
assert!(op_def.constant_folder.is_none());
assert!(matches!(
op_def.signature_func,
SignatureFunc::PolyFuncType(_)
));
assert!(
op_def
.lower_funcs
.iter()
.all(|lf| matches!(lf, LowerFunc::FixedHugr { .. }))
);
Self(op_def)
}
}
impl From<SimpleOpDef> for OpDef {
fn from(value: SimpleOpDef) -> Self {
value.0
}
}
impl PartialEq for SimpleOpDef {
fn eq(&self, other: &Self) -> bool {
let OpDef {
extension,
extension_ref: _,
name,
description,
misc,
signature_func,
lower_funcs,
constant_folder: _,
} = &self.0;
let OpDef {
extension: other_extension,
extension_ref: _,
name: other_name,
description: other_description,
misc: other_misc,
signature_func: other_signature_func,
lower_funcs: other_lower_funcs,
constant_folder: _,
} = &other.0;
let get_sig = |sf: &_| match sf {
SignatureFunc::CustomValidator(CustomValidator {
poly_func,
validate: _,
})
| SignatureFunc::PolyFuncType(poly_func)
| SignatureFunc::MissingValidateFunc(poly_func) => Some(poly_func.clone()),
SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
};
let get_lower_funcs = |lfs: &Vec<LowerFunc>| {
lfs.iter()
.map(|lf| match lf {
LowerFunc::FixedHugr { extensions, pkg } => {
Some((extensions.clone(), pkg.clone()))
}
LowerFunc::CustomFunc(_) => None,
})
.collect_vec()
};
extension == other_extension
&& name == other_name
&& description == other_description
&& misc == other_misc
&& get_sig(signature_func) == get_sig(other_signature_func)
&& get_lower_funcs(lower_funcs) == get_lower_funcs(other_lower_funcs)
}
}
#[test]
fn op_def_with_type_scheme() -> Result<(), Box<dyn std::error::Error>> {
let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap();
const OP_NAME: OpName = OpName::new_inline("Reverse");
let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear);
let list_of_var =
Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo([list_of_var]));
let def = ext.add_op(OP_NAME, "desc".into(), type_scheme, extension_ref)?;
def.add_lower_func(LowerFunc::FixedHugr {
extensions: ExtensionSet::new(),
pkg: Box::new(Package::from_hugr(crate::builder::test::simple_dfg_hugr())), });
def.add_misc("key", Default::default());
assert_eq!(def.description(), "desc");
assert_eq!(def.lower_funcs.len(), 1);
assert_eq!(def.misc.len(), 1);
Ok(())
})?;
let reg = ExtensionRegistry::new([PRELUDE.clone(), list::EXTENSION.clone(), ext]);
reg.validate()?;
let e = reg.get(&EXT_ID).unwrap();
let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t().into()])?);
let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?;
let rev = dfg.add_dataflow_op(
e.instantiate_extension_op(&OP_NAME, vec![usize_t().into()])
.unwrap(),
dfg.input_wires(),
)?;
dfg.finish_hugr_with_outputs(rev.outputs())?;
Ok(())
}
#[test]
fn binary_polyfunc() -> Result<(), Box<dyn std::error::Error>> {
struct SigFun();
impl SignatureFromArgs for SigFun {
fn compute_signature(
&self,
arg_values: &[TypeArg],
) -> Result<PolyFuncTypeRV, SignatureError> {
const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Linear);
let [TypeArg::BoundedNat(n)] = arg_values else {
return Err(SignatureError::InvalidTypeArgs);
};
let n = *n as usize;
let tvs: Vec<Type> = (0..n)
.map(|_| Type::new_var_use(0, TypeBound::Linear))
.collect();
Ok(PolyFuncTypeRV::new(
vec![TP.clone()],
Signature::new(tvs.clone(), vec![Type::new_tuple(tvs)]),
))
}
fn static_params(&self) -> &[TypeParam] {
const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat_type()];
MAX_NAT
}
}
let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
let def: &mut crate::extension::OpDef =
ext.add_op("MyOp".into(), String::new(), SigFun(), extension_ref)?;
let args = [TypeArg::BoundedNat(3), usize_t().into()];
assert_eq!(
def.compute_signature(&args),
Ok(Signature::new(
vec![usize_t(); 3],
vec![Type::new_tuple(vec![usize_t(); 3])]
))
);
assert_eq!(def.validate_args(&args, &[]), Ok(()));
let tyvar = Type::new_var_use(0, TypeBound::Copyable);
let tyvars: Vec<Type> = vec![tyvar.clone(); 3];
let args = [TypeArg::BoundedNat(3), tyvar.clone().into()];
assert_eq!(
def.compute_signature(&args),
Ok(Signature::new(
tyvars.clone(),
vec![Type::new_tuple(tyvars)]
))
);
def.validate_args(&args, &[TypeBound::Copyable.into()])
.unwrap();
assert_eq!(
def.validate_args(&args, &[TypeBound::Linear.into()]),
Err(SignatureError::TypeVarDoesNotMatchDeclaration {
actual: Box::new(TypeBound::Linear.into()),
cached: Box::new(TypeBound::Copyable.into())
})
);
let kind = TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap());
let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()];
assert_eq!(
def.compute_signature(&args),
Err(SignatureError::InvalidTypeArgs)
);
assert_eq!(
def.validate_args(&args, &[kind]),
Err(SignatureError::FreeTypeVar {
idx: 0,
num_decls: 0
})
);
Ok(())
})?;
Ok(())
}
#[test]
fn type_scheme_instantiate_var() -> Result<(), Box<dyn std::error::Error>> {
let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
let def = ext.add_op(
"SimpleOp".into(),
String::new(),
PolyFuncTypeRV::new(
vec![TypeBound::Linear.into()],
Signature::new_endo([Type::new_var_use(0, TypeBound::Linear)]),
),
extension_ref,
)?;
let tv = Type::new_var_use(0, TypeBound::Copyable);
let args = [tv.clone().into()];
let decls = [TypeBound::Copyable.into()];
def.validate_args(&args, &decls).unwrap();
assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo([tv])));
let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into();
assert_eq!(
def.compute_signature(std::slice::from_ref(&arg)),
Err(SignatureError::TypeArgMismatch(
TermTypeError::TypeMismatch {
type_: Box::new(TypeBound::Linear.into()),
term: Box::new(arg),
}
))
);
Ok(())
})?;
Ok(())
}
mod proptest {
use std::sync::Weak;
use super::SimpleOpDef;
use ::proptest::prelude::*;
use crate::package::Package;
use crate::{
builder::test::simple_dfg_hugr,
extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc, op_def::LowerFunc},
types::PolyFuncTypeRV,
};
impl Arbitrary for SignatureFunc {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
any::<PolyFuncTypeRV>()
.prop_map(SignatureFunc::PolyFuncType)
.boxed()
}
}
impl Arbitrary for LowerFunc {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
any::<ExtensionSet>()
.prop_map(|extensions| LowerFunc::FixedHugr {
extensions,
pkg: Box::new(Package::from_hugr(simple_dfg_hugr())),
})
.boxed()
}
}
impl Arbitrary for SimpleOpDef {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
use crate::proptest::{any_serde_json_value, any_smolstr, any_string};
use proptest::collection::{hash_map, vec};
let misc = hash_map(any_string(), any_serde_json_value(), 0..3);
(
any::<ExtensionId>(),
any_smolstr(),
any_string(),
misc,
any::<SignatureFunc>(),
vec(any::<LowerFunc>(), 0..2),
)
.prop_map(
|(extension, name, description, misc, signature_func, lower_funcs)| {
Self::new(OpDef {
extension,
extension_ref: Weak::default(),
name,
description,
misc,
signature_func,
lower_funcs,
constant_folder: None,
})
},
)
.boxed()
}
}
}
}