use std::collections::btree_map;
use std::collections::hash_map;
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;
use thiserror::Error;
use crate::hugr::IdentList;
use crate::ops::constant::{ValueName, ValueNameRef};
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{self, OpName, OpNameRef};
use crate::types::type_param::{TypeArg, TypeArgError, TypeParam};
use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName};
use crate::types::{FunctionType, TypeNameRef};
// Private `infer` submodule; dead-code warnings are silenced for everything in it.
#[allow(dead_code)]
mod infer;
// `infer_extensions` is only re-exported when the `extension_inference`
// feature is enabled; the solution and error types are always re-exported.
#[cfg(feature = "extension_inference")]
pub use infer::infer_extensions;
pub use infer::{ExtensionSolution, InferExtensionError};
// Private `op_def` submodule; the items below are its public surface.
mod op_def;
pub use op_def::{
CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc,
ValidateJustArgs, ValidateTypeArgs,
};
// Private `type_def` submodule, re-exporting `TypeDef` and `TypeDefBound`.
mod type_def;
pub use type_def::{TypeDef, TypeDefBound};
// `const_fold` is private (its items are re-exported below); the remaining
// submodules are public.
mod const_fold;
pub mod prelude;
pub mod simple_op;
pub mod validate;
pub use const_fold::{ConstFold, ConstFoldResult, Folder};
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};
pub mod declarative;
/// A collection of [`Extension`]s, keyed by their [`ExtensionId`].
///
/// Backed by a `BTreeMap`, so iteration follows the ordering of the ids.
#[derive(Clone, Debug)]
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Extension>);
impl ExtensionRegistry {
pub fn get(&self, name: &str) -> Option<&Extension> {
self.0.get(name)
}
pub fn contains(&self, name: &str) -> bool {
self.0.contains_key(name)
}
pub fn try_new(
value: impl IntoIterator<Item = Extension>,
) -> Result<Self, ExtensionRegistryError> {
let mut exts = BTreeMap::new();
for ext in value.into_iter() {
let prev = exts.insert(ext.name.clone(), ext);
if let Some(prev) = prev {
return Err(ExtensionRegistryError::AlreadyRegistered(
prev.name().clone(),
));
};
}
let res = ExtensionRegistry(exts);
for ext in res.0.values() {
ext.validate(&res)
.map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?;
}
Ok(res)
}
pub fn register(&mut self, extension: Extension) -> Result<&Extension, ExtensionRegistryError> {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(_) => Err(ExtensionRegistryError::AlreadyRegistered(
extension.name().clone(),
)),
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)),
}
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn iter(&self) -> impl Iterator<Item = (&ExtensionId, &Extension)> {
self.0.iter()
}
}
impl IntoIterator for ExtensionRegistry {
type Item = (ExtensionId, Extension);
type IntoIter = <BTreeMap<ExtensionId, Extension> as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
/// An [`ExtensionRegistry`] that contains no extensions.
pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry(BTreeMap::new());
/// Errors reported when a signature, or an instantiation of a definition,
/// fails validation.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[allow(missing_docs)]
pub enum SignatureError {
/// The definition's name and the instantiation's name differ.
#[error("Definition name ({0}) and instantiation name ({1}) do not match.")]
NameMismatch(TypeName, TypeName),
/// The definition's extension and the instantiation's extension differ.
#[error("Definition extension ({0:?}) and instantiation extension ({1:?}) do not match.")]
ExtensionMismatch(ExtensionId, ExtensionId),
/// The type arguments do not match the declared parameters; wraps the
/// underlying [`TypeArgError`] (convertible via `From`).
#[error("Type arguments of node did not match params declared by definition: {0}")]
TypeArgMismatch(#[from] TypeArgError),
/// The type arguments supplied to an operation are invalid.
#[error("Invalid type arguments for operation")]
InvalidTypeArgs,
/// The named extension could not be found.
#[error("Extension '{0}' not found")]
ExtensionNotFound(ExtensionId),
/// Extension `exn` has no `TypeDef` named `typ`.
#[error("Extension '{exn}' did not contain expected TypeDef '{typ}'")]
ExtensionTypeNotFound { exn: ExtensionId, typ: TypeName },
/// The bound recorded on a `CustomType` differs from the one given by its `TypeDef`.
#[error("Bound on CustomType ({actual}) did not match TypeDef ({expected})")]
WrongBound {
actual: TypeBound,
expected: TypeBound,
},
/// A type variable's cached parameter differs from its actual declaration.
#[error("Type Variable claims to be {cached:?} but actual declaration {actual:?}")]
TypeVarDoesNotMatchDeclaration {
actual: TypeParam,
cached: TypeParam,
},
/// Type variable `idx` is used but only `num_decls` variables are in scope.
#[error("Type variable {idx} was not declared ({num_decls} in scope)")]
FreeTypeVar { idx: usize, num_decls: usize },
/// Row variable `idx` appears where a single type was expected.
#[error("Expected a single type, but found row variable {idx}")]
RowVarWhereTypeExpected { idx: usize },
/// The cached result of a type application in a `Call` differs from the expected one.
#[error(
"Incorrect result of type application in Call - cached {cached} but expected {expected}"
)]
CallIncorrectlyAppliesType {
cached: FunctionType,
expected: FunctionType,
},
/// The cached result of a type application in a `LoadFunction` differs from the expected one.
#[error(
"Incorrect result of type application in LoadFunction - cached {cached} but expected {expected}"
)]
LoadFunctionIncorrectlyAppliesType {
cached: FunctionType,
expected: FunctionType,
},
}
/// Common read-only view of a concrete instantiation of a definition:
/// its name, its type arguments and the extension it belongs to.
///
/// Implemented in this file for `OpaqueOp` and `CustomType`.
trait CustomConcrete {
/// The type used to name the definition.
type Identifier;
/// Name of the definition being instantiated.
fn def_name(&self) -> &Self::Identifier;
/// Type arguments supplied by the instantiation.
fn type_args(&self) -> &[TypeArg];
/// Id of the extension the instantiation refers to.
fn parent_extension(&self) -> &ExtensionId;
}
impl CustomConcrete for OpaqueOp {
    type Identifier = OpName;

    /// The operation's name.
    fn def_name(&self) -> &Self::Identifier {
        self.name()
    }

    /// The type arguments carried by the operation.
    fn type_args(&self) -> &[TypeArg] {
        self.args()
    }

    /// The id of the extension the operation names.
    fn parent_extension(&self) -> &ExtensionId {
        self.extension()
    }
}
impl CustomConcrete for CustomType {
    type Identifier = TypeName;

    /// The type's name.
    fn def_name(&self) -> &Self::Identifier {
        self.name()
    }

    /// The type arguments carried by the type.
    fn type_args(&self) -> &[TypeArg] {
        self.args()
    }

    /// The id of the extension the type names.
    fn parent_extension(&self) -> &ExtensionId {
        self.extension()
    }
}
/// A named value stored in an [`Extension`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ExtensionValue {
// Id of the extension that holds this value.
extension: ExtensionId,
// Name the value is stored under within that extension.
name: ValueName,
// The value itself.
typed_value: ops::Value,
}
impl ExtensionValue {
    /// Borrows the stored value.
    pub fn typed_value(&self) -> &ops::Value {
        let Self { typed_value, .. } = self;
        typed_value
    }

    /// The value's name, as a string slice.
    pub fn name(&self) -> &str {
        let Self { name, .. } = self;
        name.as_str()
    }

    /// Id of the extension this value belongs to.
    pub fn extension(&self) -> &ExtensionId {
        let Self { extension, .. } = self;
        extension
    }
}
/// Identifier of an extension; an alias for [`IdentList`].
pub type ExtensionId = IdentList;
/// A named bundle of type definitions, values and operation definitions.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Extension {
/// Id of this extension.
pub name: ExtensionId,
/// Set of extension ids recorded as this extension's requirements.
pub extension_reqs: ExtensionSet,
// Type definitions, keyed by type name.
types: HashMap<TypeName, TypeDef>,
// Values, keyed by value name.
values: HashMap<ValueName, ExtensionValue>,
// Operation definitions, keyed by operation name and shared via `Arc`.
operations: HashMap<OpName, Arc<op_def::OpDef>>,
}
impl Extension {
pub fn new(name: ExtensionId) -> Self {
Self::new_with_reqs(name, ExtensionSet::default())
}
pub fn new_with_reqs(name: ExtensionId, extension_reqs: impl Into<ExtensionSet>) -> Self {
Self {
name,
extension_reqs: extension_reqs.into(),
types: Default::default(),
values: Default::default(),
operations: Default::default(),
}
}
pub fn get_op(&self, op_name: &OpNameRef) -> Option<&Arc<op_def::OpDef>> {
self.operations.get(op_name)
}
pub fn get_type(&self, type_name: &TypeNameRef) -> Option<&type_def::TypeDef> {
self.types.get(type_name)
}
pub fn get_value(&self, value_name: &ValueNameRef) -> Option<&ExtensionValue> {
self.values.get(value_name)
}
pub fn name(&self) -> &ExtensionId {
&self.name
}
pub fn operations(&self) -> impl Iterator<Item = (&OpName, &Arc<OpDef>)> {
self.operations.iter()
}
pub fn types(&self) -> impl Iterator<Item = (&TypeName, &TypeDef)> {
self.types.iter()
}
pub fn add_value(
&mut self,
name: impl Into<ValueName>,
typed_value: ops::Value,
) -> Result<&mut ExtensionValue, ExtensionBuildError> {
let extension_value = ExtensionValue {
extension: self.name.clone(),
name: name.into(),
typed_value,
};
match self.values.entry(extension_value.name.clone()) {
hash_map::Entry::Occupied(_) => {
Err(ExtensionBuildError::ValueExists(extension_value.name))
}
hash_map::Entry::Vacant(ve) => Ok(ve.insert(extension_value)),
}
}
pub fn instantiate_extension_op(
&self,
op_name: &OpNameRef,
args: impl Into<Vec<TypeArg>>,
ext_reg: &ExtensionRegistry,
) -> Result<ExtensionOp, SignatureError> {
let op_def = self.get_op(op_name).expect("Op not found.");
ExtensionOp::new(op_def.clone(), args, ext_reg)
}
fn validate(&self, all_exts: &ExtensionRegistry) -> Result<(), SignatureError> {
for op_def in self.operations.values() {
op_def.validate(all_exts)?;
}
Ok(())
}
}
impl PartialEq for Extension {
    /// Two extensions compare equal when their ids are equal; no other field
    /// takes part in the comparison.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.name, &other.name);
        lhs == rhs
    }
}
/// Errors produced when building or extending an [`ExtensionRegistry`].
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum ExtensionRegistryError {
/// An extension with this id is already present in the registry.
#[error("The registry already contains an extension with id {0}.")]
AlreadyRegistered(ExtensionId),
/// Validation of the named extension failed with the wrapped [`SignatureError`].
#[error("The extension {0} contains an invalid signature, {1}.")]
InvalidSignature(ExtensionId, #[source] SignatureError),
}
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum ExtensionBuildError {
#[error("Extension already has an op called {0}.")]
OpDefExists(OpName),
#[error("Extension already has an type called {0}.")]
TypeDefExists(TypeName),
#[error("Extension already has an extension value called {0}.")]
ValueExists(ValueName),
}
/// An ordered set of [`ExtensionId`]s.
///
/// Type variables are stored in the same set: `insert_type_var` adds an id
/// whose text is the variable's decimal index, and `as_typevar` recognises
/// such entries by their leading ASCII digit.
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ExtensionSet(BTreeSet<ExtensionId>);
impl ExtensionSet {
pub const fn new() -> Self {
Self(BTreeSet::new())
}
pub fn insert(&mut self, extension: &ExtensionId) {
self.0.insert(extension.clone());
}
pub fn insert_type_var(&mut self, idx: usize) {
self.0
.insert(ExtensionId::new_unchecked(idx.to_string().as_str()));
}
pub fn contains(&self, extension: &ExtensionId) -> bool {
self.0.contains(extension)
}
pub fn is_subset(&self, other: &Self) -> bool {
self.0.is_subset(&other.0)
}
pub fn is_superset(&self, other: &Self) -> bool {
self.0.is_superset(&other.0)
}
pub fn singleton(extension: &ExtensionId) -> Self {
let mut set = Self::new();
set.insert(extension);
set
}
pub fn type_var(idx: usize) -> Self {
let mut set = Self::new();
set.insert_type_var(idx);
set
}
pub fn union(mut self, other: Self) -> Self {
self.0.extend(other.0);
self
}
pub fn union_over(sets: impl IntoIterator<Item = Self>) -> Self {
let mut res = ExtensionSet::new();
for s in sets {
res.0.extend(s.0)
}
res
}
pub fn missing_from(&self, other: &Self) -> Self {
ExtensionSet::from_iter(other.0.difference(&self.0).cloned())
}
pub fn iter(&self) -> impl Iterator<Item = &ExtensionId> {
self.0.iter()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub(crate) fn validate(&self, params: &[TypeParam]) -> Result<(), SignatureError> {
self.iter()
.filter_map(as_typevar)
.try_for_each(|var_idx| check_typevar_decl(params, var_idx, &TypeParam::Extensions))
}
pub(crate) fn substitute(&self, t: &Substitution) -> Self {
Self::from_iter(self.0.iter().flat_map(|e| match as_typevar(e) {
None => vec![e.clone()],
Some(i) => match t.apply_var(i, &TypeParam::Extensions) {
TypeArg::Extensions{es} => es.iter().cloned().collect::<Vec<_>>(),
_ => panic!("value for type var was not extension set - type scheme should be validated first"),
},
}))
}
}
impl From<ExtensionId> for ExtensionSet {
    /// Wraps a single id in a one-element set.
    fn from(id: ExtensionId) -> Self {
        let mut only = Self::new();
        only.0.insert(id);
        only
    }
}
/// Interprets `e` as a type-variable index.
///
/// Returns `None` when `e` is empty or does not start with an ASCII digit.
///
/// # Panics
///
/// Panics if `e` starts with an ASCII digit but does not parse as a `usize`.
fn as_typevar(e: &ExtensionId) -> Option<usize> {
    let leading = e.chars().next()?;
    leading
        .is_ascii_digit()
        .then(|| str::parse::<usize>(e).unwrap())
}
impl Display for ExtensionSet {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
f.debug_list().entries(self.0.iter()).finish()
}
}
impl FromIterator<ExtensionId> for ExtensionSet {
    /// Collects the ids produced by `iter` into a set; duplicates collapse.
    fn from_iter<I: IntoIterator<Item = ExtensionId>>(iter: I) -> Self {
        let members: BTreeSet<ExtensionId> = iter.into_iter().collect();
        Self(members)
    }
}
// Test-only helpers, compiled only under `cfg(test)`.
#[cfg(test)]
pub mod test {
// Re-exports `SimpleOpDef` from the `op_def` test module.
pub use super::op_def::test::SimpleOpDef;
mod proptest {
use ::proptest::{collection::hash_set, prelude::*};
use super::super::{ExtensionId, ExtensionSet};
// Property-testing generator for `ExtensionSet`: combines a set of
// arbitrary extension ids with a set of type-variable indices.
impl Arbitrary for ExtensionSet {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
(
// Type-variable indices drawn from `0..10`; set size taken from `0..3`.
hash_set(0..10usize, 0..3),
// Arbitrary extension ids; set size taken from `0..3`.
hash_set(any::<ExtensionId>(), 0..3),
)
.prop_map(|(vars, extensions)| {
// Union the plain ids with one single-variable set per index.
ExtensionSet::union_over(
std::iter::once(extensions.into_iter().collect::<ExtensionSet>())
.chain(vars.into_iter().map(ExtensionSet::type_var)),
)
})
.boxed()
}
}
}
}