use itertools::Itertools;
use resolution::{ExtensionResolutionError, WeakExtensionRegistry};
pub use semver::Version;
use serde::{Deserialize, Deserializer, Serialize};
use std::cell::UnsafeCell;
use std::collections::btree_map;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::Debug;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Weak};
use std::{io, mem};
use derive_more::Display;
use thiserror::Error;
use crate::hugr::IdentList;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{OpName, OpNameRef};
use crate::types::RowVariable;
use crate::types::type_param::{TermTypeError, TypeArg, TypeParam};
use crate::types::{CustomType, TypeBound, TypeName};
use crate::types::{Signature, TypeNameRef};
mod const_fold;
mod op_def;
pub mod prelude;
pub mod resolution;
pub mod simple_op;
mod type_def;
pub use const_fold::{ConstFold, ConstFoldResult, Folder, fold_out_row};
pub use op_def::{
CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc,
ValidateJustArgs, ValidateTypeArgs, deserialize_lower_funcs,
};
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};
pub use type_def::{TypeDef, TypeDefBound};
#[cfg(feature = "declarative")]
pub mod declarative;
/// A collection of [`Extension`]s, keyed by their [`ExtensionId`].
///
/// Displays as `ExtensionRegistry[...]` listing the registered ids.
#[derive(Debug, Display, Default)]
#[display("ExtensionRegistry[{}]", exts.keys().join(", "))]
pub struct ExtensionRegistry {
/// The registered extensions, indexed by extension name.
exts: BTreeMap<ExtensionId, Arc<Extension>>,
/// Cached validation state: set by a successful `validate` call and
/// cleared by every method that modifies `exts`.
valid: AtomicBool,
}
/// Two registries are equal when their extension maps are equal.
///
/// The cached `valid` flag is deliberately left out of the comparison.
impl PartialEq for ExtensionRegistry {
    fn eq(&self, other: &Self) -> bool {
        let (ours, theirs) = (&self.exts, &other.exts);
        ours == theirs
    }
}
/// Cloning copies the extension map (cloning each `Arc`) and snapshots the
/// current value of the `valid` flag into a fresh `AtomicBool`.
impl Clone for ExtensionRegistry {
    fn clone(&self) -> Self {
        let is_valid = self.valid.load(Ordering::Relaxed);
        Self {
            exts: self.exts.clone(),
            valid: AtomicBool::new(is_valid),
        }
    }
}
impl ExtensionRegistry {
    /// Builds a registry holding the given extensions.
    ///
    /// Each extension is added through [`ExtensionRegistry::register_updated`],
    /// so an extension only replaces an earlier one with the same name when its
    /// version is strictly greater.
    pub fn new(extensions: impl IntoIterator<Item = Arc<Extension>>) -> Self {
        let mut registry = Self::default();
        extensions
            .into_iter()
            .for_each(|ext| registry.register_updated(ext));
        registry
    }

    /// Reads a JSON list of [`Extension`]s from `reader` and builds a registry
    /// from them via `new_with_extension_resolution`, passing along
    /// `other_extensions` (converted to a weak registry).
    ///
    /// # Errors
    ///
    /// Returns an [`ExtensionRegistryLoadError`] if the JSON cannot be
    /// deserialized or if extension resolution fails.
    pub fn load_json(
        reader: impl io::Read,
        other_extensions: &ExtensionRegistry,
    ) -> Result<Self, ExtensionRegistryLoadError> {
        let extensions: Vec<Extension> = serde_json::from_reader(reader)?;
        let registry =
            ExtensionRegistry::new_with_extension_resolution(extensions, &other_extensions.into())?;
        Ok(registry)
    }

    /// Looks up the extension registered under `name`.
    pub fn get(&self, name: &str) -> Option<&Arc<Extension>> {
        self.exts.get(name)
    }

    /// Returns `true` if an extension is registered under `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }

    /// Validates every registered extension.
    ///
    /// A successful run is remembered in the `valid` flag, so repeated calls
    /// return immediately until the registry is modified again.
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionRegistryError::InvalidSignature`] for the first
    /// extension whose validation fails.
    pub fn validate(&self) -> Result<(), ExtensionRegistryError> {
        if !self.valid.load(Ordering::Relaxed) {
            self.exts.values().try_for_each(|ext| {
                ext.validate()
                    .map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))
            })?;
            self.valid.store(true, Ordering::Relaxed);
        }
        Ok(())
    }

    /// Registers a new extension.
    ///
    /// # Errors
    ///
    /// Returns [`ExtensionRegistryError::AlreadyRegistered`] (carrying the
    /// registered and the new version) if an extension with the same name is
    /// already present; the registry is left untouched in that case.
    pub fn register(
        &mut self,
        extension: impl Into<Arc<Extension>>,
    ) -> Result<(), ExtensionRegistryError> {
        let new_ext: Arc<Extension> = extension.into();
        match self.exts.entry(new_ext.name().clone()) {
            btree_map::Entry::Vacant(slot) => {
                slot.insert(new_ext);
                // The contents changed, so any cached validation is stale.
                self.valid.store(false, Ordering::Relaxed);
                Ok(())
            }
            btree_map::Entry::Occupied(existing) => {
                let registered = Box::new(existing.get().version().clone());
                let attempted = Box::new(new_ext.version().clone());
                Err(ExtensionRegistryError::AlreadyRegistered(
                    new_ext.name().clone(),
                    registered,
                    attempted,
                ))
            }
        }
    }

    /// Registers an extension, replacing an existing one of the same name only
    /// if the new version is strictly greater.
    ///
    /// The `valid` flag is cleared regardless of whether anything was replaced.
    pub fn register_updated(&mut self, extension: impl Into<Arc<Extension>>) {
        let extension: Arc<Extension> = extension.into();
        self.register_updated_ref(&extension);
    }

    /// Same as [`ExtensionRegistry::register_updated`], but takes the extension
    /// by reference and only clones the `Arc` when it is actually stored.
    pub fn register_updated_ref(&mut self, extension: &Arc<Extension>) {
        match self.exts.entry(extension.name().clone()) {
            btree_map::Entry::Vacant(slot) => {
                slot.insert(Arc::clone(extension));
            }
            btree_map::Entry::Occupied(mut current) => {
                if extension.version() > current.get().version() {
                    *current.get_mut() = Arc::clone(extension);
                }
            }
        }
        // Cleared unconditionally so the registry is re-validated.
        self.valid.store(false, Ordering::Relaxed);
    }

    /// Number of registered extensions.
    pub fn len(&self) -> usize {
        self.exts.len()
    }

    /// Returns `true` if no extension is registered.
    pub fn is_empty(&self) -> bool {
        self.exts.is_empty()
    }

    /// Iterates over the registered extensions, in the map's key order.
    pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
        self.exts.values()
    }

    /// Iterates over the ids of the registered extensions.
    pub fn ids(&self) -> impl Iterator<Item = &ExtensionId> {
        self.exts.keys()
    }

    /// Removes the extension registered under `name`, returning it if present.
    ///
    /// The `valid` flag is cleared even when nothing was removed.
    pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Arc<Extension>> {
        let removed = self.exts.remove(name);
        self.valid.store(false, Ordering::Relaxed);
        removed
    }

    /// Builds a registry whose extensions can hold weak references to each
    /// other (and to themselves).
    ///
    /// A placeholder `Arc` (same name and version, no definitions) is allocated
    /// for every input extension and weak pointers to those placeholders are
    /// collected into a [`WeakExtensionRegistry`]. `init` receives the original
    /// extensions together with that weak registry and returns the finished
    /// extensions, which are then written into the placeholder allocations.
    ///
    /// The returned extensions are paired with the placeholders positionally,
    /// so `init` is expected to return them in the order it received them.
    /// Upgrading a weak reference while `init` runs yields the placeholder.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `init`.
    pub fn new_cyclic<F, E>(
        extensions: impl IntoIterator<Item = Extension>,
        init: F,
    ) -> Result<Self, E>
    where
        F: FnOnce(Vec<Extension>, &WeakExtensionRegistry) -> Result<Vec<Extension>, E>,
    {
        let extensions = extensions.into_iter().collect_vec();

        // Interior-mutable wrapper around an extension. `repr(transparent)` is
        // what the transmutes below rely on for layout compatibility.
        #[repr(transparent)]
        struct ExtensionCell {
            ext: UnsafeCell<Extension>,
        }

        // One placeholder allocation per extension, plus a weak pointer to it
        // typed as `Weak<Extension>`.
        let (cells, weaks): (Vec<Arc<ExtensionCell>>, Vec<Weak<Extension>>) = extensions
            .iter()
            .map(|ext| {
                // `UnsafeCell` makes the wrapper `!Sync`, which triggers this lint.
                #[allow(clippy::arc_with_non_send_sync)]
                let arc = Arc::new(ExtensionCell {
                    ext: UnsafeCell::new(Extension::new(ext.name().clone(), ext.version().clone())),
                });
                // SAFETY: relies on `ExtensionCell` being `repr(transparent)` over
                // `UnsafeCell<Extension>`, so that `Weak<ExtensionCell>` and
                // `Weak<Extension>` are treated as layout-compatible.
                let weak_arc: Weak<Extension> = unsafe { mem::transmute(Arc::downgrade(&arc)) };
                (arc, weak_arc)
            })
            .unzip();

        let mut weak_registry = WeakExtensionRegistry::default();
        for (ext, weak) in extensions.iter().zip(weaks) {
            weak_registry.register(ext.name().clone(), weak);
        }

        // User-provided initialization; may embed the weak references.
        let extensions = init(extensions, &weak_registry)?;

        let finished: Vec<Arc<Extension>> = cells
            .into_iter()
            .zip(extensions)
            .map(|(arc, ext)| {
                // SAFETY: relies on no reference into the placeholder being alive
                // while it is overwritten.
                // NOTE(review): this is not enforced by the types — `init` must
                // not keep an upgraded `Arc` or a borrow of a placeholder.
                unsafe { *arc.ext.get() = ext };
                // SAFETY: same layout argument as for the weak pointers above.
                unsafe { mem::transmute::<Arc<ExtensionCell>, Arc<Extension>>(arc) }
            })
            .collect();
        Ok(ExtensionRegistry::new(finished))
    }
}
impl IntoIterator for ExtensionRegistry {
type Item = Arc<Extension>;
type IntoIter = std::collections::btree_map::IntoValues<ExtensionId, Arc<Extension>>;
fn into_iter(self) -> Self::IntoIter {
self.exts.into_values()
}
}
impl<'a> IntoIterator for &'a ExtensionRegistry {
type Item = &'a Arc<Extension>;
type IntoIter = std::collections::btree_map::Values<'a, ExtensionId, Arc<Extension>>;
fn into_iter(self) -> Self::IntoIter {
self.exts.values()
}
}
/// Adds borrowed extensions one by one through `register_updated_ref`.
impl<'a> Extend<&'a Arc<Extension>> for ExtensionRegistry {
    fn extend<T: IntoIterator<Item = &'a Arc<Extension>>>(&mut self, iter: T) {
        iter.into_iter()
            .for_each(|ext| self.register_updated_ref(ext));
    }
}
/// Adds owned extensions one by one through `register_updated`.
impl Extend<Arc<Extension>> for ExtensionRegistry {
    fn extend<T: IntoIterator<Item = Arc<Extension>>>(&mut self, iter: T) {
        iter.into_iter().for_each(|ext| self.register_updated(ext));
    }
}
/// Deserializes a sequence of extensions and collects them with
/// [`ExtensionRegistry::new`].
impl<'de> Deserialize<'de> for ExtensionRegistry {
    fn deserialize<D>(deserializer: D) -> Result<ExtensionRegistry, D::Error>
    where
        D: Deserializer<'de>,
    {
        <Vec<Arc<Extension>>>::deserialize(deserializer).map(ExtensionRegistry::new)
    }
}
impl Serialize for ExtensionRegistry {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let extensions: Vec<Arc<Extension>> = self.exts.values().cloned().collect();
extensions.serialize(serializer)
}
}
/// An [`ExtensionRegistry`] containing no extensions.
///
/// Its `valid` flag starts out set, so `validate` returns `Ok(())` without
/// doing any work.
pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry {
exts: BTreeMap::new(),
valid: AtomicBool::new(true),
};
/// Errors describing an invalid signature, type instantiation or type
/// argument; see the message attached to each variant.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum SignatureError {
/// The name in a definition differs from the name used at instantiation.
#[error("Definition name ({0}) and instantiation name ({1}) do not match.")]
NameMismatch(TypeName, TypeName),
/// The extension of a definition differs from the one used at instantiation.
#[error("Definition extension ({0}) and instantiation extension ({1}) do not match.")]
ExtensionMismatch(ExtensionId, ExtensionId),
/// Type arguments do not match the declared parameters; wraps (and converts
/// from) a [`TermTypeError`].
#[error("Type arguments of node did not match params declared by definition: {0}")]
TypeArgMismatch(#[from] TermTypeError),
/// The type arguments supplied to an operation are invalid.
#[error("Invalid type arguments for operation")]
InvalidTypeArgs,
/// The reference to the extension defining a type has been dropped.
#[error(
"Type '{typ}' is defined in extension '{missing}', but the extension reference has been dropped."
)]
MissingTypeExtension { typ: TypeName, missing: ExtensionId },
/// The extension exists but has no `TypeDef` with the expected name.
#[error("Extension '{exn}' did not contain expected TypeDef '{typ}'")]
ExtensionTypeNotFound { exn: ExtensionId, typ: TypeName },
/// The bound stored on a `CustomType` differs from the one given by its `TypeDef`.
#[error("Bound on CustomType ({actual}) did not match TypeDef ({expected})")]
WrongBound {
actual: TypeBound,
expected: TypeBound,
},
/// A type variable's cached parameter differs from its actual declaration.
#[error("Type Variable claims to be {cached} but actual declaration {actual}")]
TypeVarDoesNotMatchDeclaration {
actual: Box<TypeParam>,
cached: Box<TypeParam>,
},
/// A type variable index refers to no declaration in scope.
#[error("Type variable {idx} was not declared ({num_decls} in scope)")]
FreeTypeVar { idx: usize, num_decls: usize },
/// A row variable appeared where a single type was required.
#[error("Expected a single type, but found row variable {var}")]
RowVarWhereTypeExpected { var: RowVariable },
/// The signature cached on a `Call` differs from the expected type application.
#[error(
"Incorrect result of type application in Call - cached {cached} but expected {expected}"
)]
CallIncorrectlyAppliesType {
cached: Box<Signature>,
expected: Box<Signature>,
},
/// The signature cached on a `LoadFunction` differs from the expected type application.
#[error(
"Incorrect result of type application in LoadFunction - cached {cached} but expected {expected}"
)]
LoadFunctionIncorrectlyAppliesType {
cached: Box<Signature>,
expected: Box<Signature>,
},
/// A binary signature-computation function is needed but has not been loaded.
#[error("Binary compute signature function not loaded.")]
MissingComputeFunc,
/// A binary signature-validation function is needed but has not been loaded.
#[error("Binary validate signature function not loaded.")]
MissingValidateFunc,
}
/// Common view over concrete instantiations of extension-defined items
/// (implemented in this file for [`OpaqueOp`] and [`CustomType`]).
trait CustomConcrete {
/// The kind of name that identifies the definition.
type Identifier;
/// Name of the definition this value instantiates.
fn def_name(&self) -> &Self::Identifier;
/// Type arguments of the instantiation.
fn type_args(&self) -> &[TypeArg];
/// Id of the extension the definition belongs to.
fn parent_extension(&self) -> &ExtensionId;
}
/// `OpaqueOp` exposes its unqualified id, arguments and extension through
/// the [`CustomConcrete`] view.
impl CustomConcrete for OpaqueOp {
    type Identifier = OpName;

    fn def_name(&self) -> &OpName {
        self.unqualified_id()
    }

    fn type_args(&self) -> &[TypeArg] {
        self.args()
    }

    fn parent_extension(&self) -> &ExtensionId {
        self.extension()
    }
}
/// `CustomType` exposes its name, arguments and extension through the
/// [`CustomConcrete`] view.
impl CustomConcrete for CustomType {
    type Identifier = TypeName;

    fn def_name(&self) -> &Self::Identifier {
        self.name()
    }

    fn type_args(&self) -> &[TypeArg] {
        self.args()
    }

    fn parent_extension(&self) -> &ExtensionId {
        self.extension()
    }
}
/// Identifier of an [`Extension`]; an alias for [`IdentList`].
pub type ExtensionId = IdentList;
/// A named, versioned collection of type and operation definitions.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Extension {
/// Version of the extension (a `semver` [`Version`]).
pub version: Version,
/// Identifier of the extension.
pub name: ExtensionId,
/// Type definitions of this extension, keyed by type name.
types: BTreeMap<TypeName, TypeDef>,
/// Operation definitions of this extension, keyed by operation name.
operations: BTreeMap<OpName, Arc<op_def::OpDef>>,
}
impl Extension {
#[must_use]
pub fn new(name: ExtensionId, version: Version) -> Self {
Self {
name,
version,
types: Default::default(),
operations: Default::default(),
}
}
pub fn new_arc(
name: ExtensionId,
version: Version,
init: impl FnOnce(&mut Extension, &Weak<Extension>),
) -> Arc<Self> {
Arc::new_cyclic(|extension_ref| {
let mut ext = Self::new(name, version);
init(&mut ext, extension_ref);
ext
})
}
pub fn try_new_arc<E>(
name: ExtensionId,
version: Version,
init: impl FnOnce(&mut Extension, &Weak<Extension>) -> Result<(), E>,
) -> Result<Arc<Self>, E> {
let mut error = None;
let ext = Arc::new_cyclic(|extension_ref| {
let mut ext = Self::new(name, version);
match init(&mut ext, extension_ref) {
Ok(()) => ext,
Err(e) => {
error = Some(e);
ext
}
}
});
match error {
Some(e) => Err(e),
None => Ok(ext),
}
}
#[must_use]
pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc<op_def::OpDef>> {
self.operations.get(name)
}
#[must_use]
pub fn get_type(&self, type_name: &TypeNameRef) -> Option<&type_def::TypeDef> {
self.types.get(type_name)
}
#[must_use]
pub fn name(&self) -> &ExtensionId {
&self.name
}
#[must_use]
pub fn version(&self) -> &Version {
&self.version
}
pub fn operations(&self) -> impl Iterator<Item = (&OpName, &Arc<OpDef>)> {
self.operations.iter()
}
pub fn types(&self) -> impl Iterator<Item = (&TypeName, &TypeDef)> {
self.types.iter()
}
pub fn instantiate_extension_op(
&self,
name: &OpNameRef,
args: impl Into<Vec<TypeArg>>,
) -> Result<ExtensionOp, SignatureError> {
let op_def = self.get_op(name).expect("Op not found.");
ExtensionOp::new(op_def.clone(), args)
}
fn validate(&self) -> Result<(), SignatureError> {
for op_def in self.operations.values() {
op_def.validate()?;
}
Ok(())
}
}
/// Extensions compare equal when both name and version match; the type and
/// operation definitions are not compared.
impl PartialEq for Extension {
    fn eq(&self, other: &Self) -> bool {
        (&self.name, &self.version) == (&other.name, &other.version)
    }
}
/// Errors reported by [`ExtensionRegistry`] operations.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExtensionRegistryError {
/// An extension with this id is already registered. Carries the id, the
/// registered version and the version of the rejected extension.
#[error(
"The registry already contains an extension with id {0} and version {1}. New extension has version {2}."
)]
AlreadyRegistered(ExtensionId, Box<Version>, Box<Version>),
/// Validation of the named extension failed with the given [`SignatureError`].
#[error("The extension {0} contains an invalid signature, {1}.")]
InvalidSignature(ExtensionId, #[source] SignatureError),
}
/// Errors that can occur while loading an [`ExtensionRegistry`]
/// (returned by `ExtensionRegistry::load_json`).
#[derive(Debug, Error)]
#[non_exhaustive]
#[error("Extension registry load error")]
pub enum ExtensionRegistryLoadError {
/// The input could not be deserialized; displays as the wrapped
/// `serde_json` error.
#[error(transparent)]
SerdeError(#[from] serde_json::Error),
/// Extension resolution failed; displays as the wrapped error, which is
/// stored boxed.
#[error(transparent)]
ExtensionResolutionError(Box<ExtensionResolutionError>),
}
impl From<ExtensionResolutionError> for ExtensionRegistryLoadError {
fn from(error: ExtensionResolutionError) -> Self {
Self::ExtensionResolutionError(Box::new(error))
}
}
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExtensionBuildError {
#[error("Extension already has an op called {0}.")]
OpDefExists(OpName),
#[error("Extension already has an type called {0}.")]
TypeDefExists(TypeName),
}
/// An ordered set of [`ExtensionId`]s.
///
/// Displays as the ids separated by `", "` inside square brackets.
#[derive(
Clone, Debug, Display, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize,
)]
#[display("[{}]", _0.iter().join(", "))]
pub struct ExtensionSet(BTreeSet<ExtensionId>);
impl ExtensionSet {
    /// Creates an empty set.
    #[must_use]
    pub const fn new() -> Self {
        Self(BTreeSet::new())
    }

    /// Adds an extension id to the set; inserting an id that is already
    /// present leaves the set unchanged.
    pub fn insert(&mut self, extension: ExtensionId) {
        // `extension` is owned, so it is moved into the set directly rather
        // than cloned and then dropped.
        self.0.insert(extension);
    }

    /// Returns `true` if the set contains `extension`.
    #[must_use]
    pub fn contains(&self, extension: &ExtensionId) -> bool {
        self.0.contains(extension)
    }

    /// Returns `true` if every id in `self` is also in `other`.
    #[must_use]
    pub fn is_subset(&self, other: &Self) -> bool {
        self.0.is_subset(&other.0)
    }

    /// Returns `true` if every id in `other` is also in `self`.
    #[must_use]
    pub fn is_superset(&self, other: &Self) -> bool {
        self.0.is_superset(&other.0)
    }

    /// Creates a set containing exactly one extension id.
    #[must_use]
    pub fn singleton(extension: ExtensionId) -> Self {
        let mut set = Self::new();
        set.insert(extension);
        set
    }

    /// Returns the union of `self` and `other`, consuming both.
    #[must_use]
    pub fn union(mut self, other: Self) -> Self {
        self.0.extend(other.0);
        self
    }

    /// Returns the union of all the given sets (empty if there are none).
    pub fn union_over(sets: impl IntoIterator<Item = Self>) -> Self {
        let mut res = ExtensionSet::new();
        for s in sets {
            res.0.extend(s.0);
        }
        res
    }

    /// Returns the ids that are in `other` but not in `self`.
    #[must_use]
    pub fn missing_from(&self, other: &Self) -> Self {
        ExtensionSet::from_iter(other.0.difference(&self.0).cloned())
    }

    /// Iterates over the ids in the set, in sorted order.
    pub fn iter(&self) -> impl Iterator<Item = &ExtensionId> {
        self.0.iter()
    }

    /// Returns `true` if the set contains no ids.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}
/// Converts a single id into the set containing only that id.
impl From<ExtensionId> for ExtensionSet {
    fn from(id: ExtensionId) -> Self {
        std::iter::once(id).collect()
    }
}
impl IntoIterator for ExtensionSet {
type Item = ExtensionId;
type IntoIter = std::collections::btree_set::IntoIter<ExtensionId>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl<'a> IntoIterator for &'a ExtensionSet {
type Item = &'a ExtensionId;
type IntoIter = std::collections::btree_set::Iter<'a, ExtensionId>;
fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}
/// Collects ids into a set; duplicates collapse into a single entry.
impl FromIterator<ExtensionId> for ExtensionSet {
    fn from_iter<I: IntoIterator<Item = ExtensionId>>(iter: I) -> Self {
        let ids: BTreeSet<ExtensionId> = iter.into_iter().collect();
        Self(ids)
    }
}
#[cfg(test)]
pub mod test {
pub use super::op_def::test::SimpleOpDef;
use super::*;
impl Extension {
    /// Test helper: [`Extension::new_arc`] with the version fixed to `0.0.0`.
    pub(crate) fn new_test_arc(
        name: ExtensionId,
        init: impl FnOnce(&mut Extension, &Weak<Extension>),
    ) -> Arc<Self> {
        let version = Version::new(0, 0, 0);
        Self::new_arc(name, version, init)
    }

    /// Test helper: [`Extension::try_new_arc`] with the version fixed to
    /// `0.0.0` and a boxed error type.
    pub(crate) fn try_new_test_arc(
        name: ExtensionId,
        init: impl FnOnce(
            &mut Extension,
            &Weak<Extension>,
        ) -> Result<(), Box<dyn std::error::Error>>,
    ) -> Result<Arc<Self>, Box<dyn std::error::Error>> {
        let version = Version::new(0, 0, 0);
        Self::try_new_arc(name, version, init)
    }
}
/// Exercises `register`, `register_updated`, `register_updated_ref` and
/// `remove_extension`, checking that the by-value and by-reference update
/// paths keep two registries in sync.
#[test]
fn test_register_update() {
    let mut reg = ExtensionRegistry::default();
    let mut reg_ref = ExtensionRegistry::default();

    let ext_1_id = ExtensionId::new("ext1").unwrap();
    let ext_2_id = ExtensionId::new("ext2").unwrap();
    let ext1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 0, 0)));
    let ext1_1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 1, 0)));
    let ext1_2 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(0, 2, 0)));
    let ext2 = Arc::new(Extension::new(ext_2_id, Version::new(1, 0, 0)));

    reg.register(ext1.clone()).unwrap();
    reg_ref.register(ext1.clone()).unwrap();
    assert_eq!(&reg, &reg_ref);

    // Plain `register` refuses a second extension with the same name.
    assert_eq!(
        reg.register(ext1_1.clone()),
        Err(ExtensionRegistryError::AlreadyRegistered(
            ext_1_id.clone(),
            Box::new(Version::new(1, 0, 0)),
            Box::new(Version::new(1, 1, 0))
        ))
    );

    // A newer version replaces the registered one.
    reg_ref.register_updated_ref(&ext1_1);
    reg.register_updated(ext1_1.clone());
    assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
    assert_eq!(&reg, &reg_ref);

    // An older version is ignored.
    reg_ref.register_updated_ref(&ext1_2);
    reg.register_updated(ext1_2.clone());
    assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
    assert_eq!(&reg, &reg_ref);

    reg.register(ext2.clone()).unwrap();
    assert_eq!(reg.get("ext2").unwrap().version(), &Version::new(1, 0, 0));
    assert_eq!(reg.len(), 2);

    assert_eq!(
        reg.remove_extension(&ext_1_id).unwrap().version(),
        &Version::new(1, 1, 0)
    );
    assert_eq!(reg.len(), 1);
}
/// `proptest` support for [`ExtensionSet`].
mod proptest {
    use ::proptest::{collection::hash_set, prelude::*};

    use super::super::{ExtensionId, ExtensionSet};

    /// Generates sets built from zero to two arbitrary extension ids.
    impl Arbitrary for ExtensionSet {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
            let ids = hash_set(any::<ExtensionId>(), 0..3);
            ids.prop_map(ExtensionSet::from_iter).boxed()
        }
    }
}
}