1use itertools::Itertools;
7use resolution::{ExtensionResolutionError, WeakExtensionRegistry};
8pub use semver::Version;
9use serde::{Deserialize, Deserializer, Serialize};
10use std::cell::UnsafeCell;
11use std::collections::btree_map;
12use std::collections::{BTreeMap, BTreeSet};
13use std::fmt::Debug;
14use std::sync::atomic::{AtomicBool, Ordering};
15use std::sync::{Arc, Weak};
16use std::{io, mem};
17
18use derive_more::Display;
19use thiserror::Error;
20
21use crate::hugr::IdentList;
22use crate::ops::custom::{ExtensionOp, OpaqueOp};
23use crate::ops::{OpName, OpNameRef};
24use crate::types::RowVariable;
25use crate::types::type_param::{TermTypeError, TypeArg, TypeParam};
26use crate::types::{CustomType, TypeBound, TypeName};
27use crate::types::{Signature, TypeNameRef};
28
29mod const_fold;
30mod op_def;
31pub mod prelude;
32pub mod resolution;
33pub mod simple_op;
34mod type_def;
35
36pub use const_fold::{ConstFold, ConstFoldResult, Folder, fold_out_row};
37pub use op_def::{
38 CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc,
39 ValidateJustArgs, ValidateTypeArgs, deserialize_lower_funcs,
40};
41pub use prelude::{PRELUDE, PRELUDE_REGISTRY};
42pub use type_def::{TypeDef, TypeDefBound};
43
44#[cfg(feature = "declarative")]
45pub mod declarative;
46
47#[derive(Debug, Display, Default)]
49#[display("ExtensionRegistry[{}]", exts.keys().join(", "))]
50pub struct ExtensionRegistry {
51 exts: BTreeMap<ExtensionId, Arc<Extension>>,
53 valid: AtomicBool,
60}
61
62impl PartialEq for ExtensionRegistry {
63 fn eq(&self, other: &Self) -> bool {
64 self.exts == other.exts
65 }
66}
67
68impl Clone for ExtensionRegistry {
69 fn clone(&self) -> Self {
70 Self {
71 exts: self.exts.clone(),
72 valid: self.valid.load(Ordering::Relaxed).into(),
73 }
74 }
75}
76
77impl ExtensionRegistry {
78 pub fn new(extensions: impl IntoIterator<Item = Arc<Extension>>) -> Self {
80 let mut res = Self::default();
81 for ext in extensions {
82 res.register_updated(ext);
83 }
84 res
85 }
86
87 pub fn load_json(
93 reader: impl io::Read,
94 other_extensions: &ExtensionRegistry,
95 ) -> Result<Self, ExtensionRegistryLoadError> {
96 let extensions: Vec<Extension> = serde_json::from_reader(reader)?;
97 Ok(ExtensionRegistry::new_with_extension_resolution(
100 extensions,
101 &other_extensions.into(),
102 )?)
103 }
104
105 pub fn get(&self, name: &str) -> Option<&Arc<Extension>> {
107 self.exts.get(name)
108 }
109
110 pub fn contains(&self, name: &str) -> bool {
112 self.exts.contains_key(name)
113 }
114
115 pub fn validate(&self) -> Result<(), ExtensionRegistryError> {
117 if self.valid.load(Ordering::Relaxed) {
118 return Ok(());
119 }
120 for ext in self.exts.values() {
121 ext.validate()
122 .map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?;
123 }
124 self.valid.store(true, Ordering::Relaxed);
125 Ok(())
126 }
127
128 pub fn register(
132 &mut self,
133 extension: impl Into<Arc<Extension>>,
134 ) -> Result<(), ExtensionRegistryError> {
135 let extension = extension.into();
136 match self.exts.entry(extension.name().clone()) {
137 btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered(
138 extension.name().clone(),
139 Box::new(prev.get().version().clone()),
140 Box::new(extension.version().clone()),
141 )),
142 btree_map::Entry::Vacant(ve) => {
143 ve.insert(extension);
144 self.valid.store(false, Ordering::Relaxed);
146
147 Ok(())
148 }
149 }
150 }
151
152 pub fn register_updated(&mut self, extension: impl Into<Arc<Extension>>) {
162 let extension = extension.into();
163 match self.exts.entry(extension.name().clone()) {
164 btree_map::Entry::Occupied(mut prev) => {
165 if prev.get().version() < extension.version() {
166 *prev.get_mut() = extension;
167 }
168 }
169 btree_map::Entry::Vacant(ve) => {
170 ve.insert(extension);
171 }
172 }
173 self.valid.store(false, Ordering::Relaxed);
175 }
176
177 pub fn register_updated_ref(&mut self, extension: &Arc<Extension>) {
187 match self.exts.entry(extension.name().clone()) {
188 btree_map::Entry::Occupied(mut prev) => {
189 if prev.get().version() < extension.version() {
190 *prev.get_mut() = extension.clone();
191 }
192 }
193 btree_map::Entry::Vacant(ve) => {
194 ve.insert(extension.clone());
195 }
196 }
197 self.valid.store(false, Ordering::Relaxed);
199 }
200
201 pub fn len(&self) -> usize {
203 self.exts.len()
204 }
205
206 pub fn is_empty(&self) -> bool {
208 self.exts.is_empty()
209 }
210
211 pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
213 self.exts.values()
214 }
215
216 pub fn ids(&self) -> impl Iterator<Item = &ExtensionId> {
218 self.exts.keys()
219 }
220
221 pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Arc<Extension>> {
223 self.valid.store(false, Ordering::Relaxed);
225
226 self.exts.remove(name)
227 }
228
229 pub fn new_cyclic<F, E>(
244 extensions: impl IntoIterator<Item = Extension>,
245 init: F,
246 ) -> Result<Self, E>
247 where
248 F: FnOnce(Vec<Extension>, &WeakExtensionRegistry) -> Result<Vec<Extension>, E>,
249 {
250 let extensions = extensions.into_iter().collect_vec();
251
252 #[repr(transparent)]
256 struct ExtensionCell {
257 ext: UnsafeCell<Extension>,
258 }
259
260 let (arcs, weaks): (Vec<Arc<ExtensionCell>>, Vec<Weak<Extension>>) = extensions
266 .iter()
267 .map(|ext| {
268 #[allow(clippy::arc_with_non_send_sync)]
274 let arc = Arc::new(ExtensionCell {
275 ext: UnsafeCell::new(Extension::new(ext.name().clone(), ext.version().clone())),
276 });
277
278 let weak_arc: Weak<Extension> = unsafe { mem::transmute(Arc::downgrade(&arc)) };
280 (arc, weak_arc)
281 })
282 .unzip();
283
284 let mut weak_registry = WeakExtensionRegistry::default();
285 for (ext, weak) in extensions.iter().zip(weaks) {
286 weak_registry.register(ext.name().clone(), weak);
287 }
288
289 let extensions = init(extensions, &weak_registry)?;
292
293 let arcs: Vec<Arc<Extension>> = arcs
295 .into_iter()
296 .zip(extensions)
297 .map(|(arc, ext)| {
298 unsafe { *arc.ext.get() = ext };
301 unsafe { mem::transmute::<Arc<ExtensionCell>, Arc<Extension>>(arc) }
304 })
305 .collect();
306 Ok(ExtensionRegistry::new(arcs))
307 }
308}
309
310impl IntoIterator for ExtensionRegistry {
311 type Item = Arc<Extension>;
312
313 type IntoIter = std::collections::btree_map::IntoValues<ExtensionId, Arc<Extension>>;
314
315 fn into_iter(self) -> Self::IntoIter {
316 self.exts.into_values()
317 }
318}
319
320impl<'a> IntoIterator for &'a ExtensionRegistry {
321 type Item = &'a Arc<Extension>;
322
323 type IntoIter = std::collections::btree_map::Values<'a, ExtensionId, Arc<Extension>>;
324
325 fn into_iter(self) -> Self::IntoIter {
326 self.exts.values()
327 }
328}
329
330impl<'a> Extend<&'a Arc<Extension>> for ExtensionRegistry {
331 fn extend<T: IntoIterator<Item = &'a Arc<Extension>>>(&mut self, iter: T) {
332 for ext in iter {
333 self.register_updated_ref(ext);
334 }
335 }
336}
337
338impl Extend<Arc<Extension>> for ExtensionRegistry {
339 fn extend<T: IntoIterator<Item = Arc<Extension>>>(&mut self, iter: T) {
340 for ext in iter {
341 self.register_updated(ext);
342 }
343 }
344}
345
346impl<'de> Deserialize<'de> for ExtensionRegistry {
351 fn deserialize<D>(deserializer: D) -> Result<ExtensionRegistry, D::Error>
352 where
353 D: Deserializer<'de>,
354 {
355 let extensions: Vec<Arc<Extension>> = Vec::deserialize(deserializer)?;
356 Ok(ExtensionRegistry::new(extensions))
357 }
358}
359
360impl Serialize for ExtensionRegistry {
361 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
362 where
363 S: serde::Serializer,
364 {
365 let extensions: Vec<Arc<Extension>> = self.exts.values().cloned().collect();
366 extensions.serialize(serializer)
367 }
368}
369
370pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry {
372 exts: BTreeMap::new(),
373 valid: AtomicBool::new(true),
374};
375
/// Errors raised while checking definitions against their instantiations, or
/// while computing and validating signatures.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum SignatureError {
    /// The name of a definition differs from the name used to instantiate it
    /// (fields: definition name, instantiation name).
    #[error("Definition name ({0}) and instantiation name ({1}) do not match.")]
    NameMismatch(TypeName, TypeName),
    /// The extension of a definition differs from the extension named by the
    /// instantiation (fields: definition extension, instantiation extension).
    #[error("Definition extension ({0}) and instantiation extension ({1}) do not match.")]
    ExtensionMismatch(ExtensionId, ExtensionId),
    /// The type arguments do not match the parameters declared by the
    /// definition; converted from [`TermTypeError`] via `#[from]`.
    #[error("Type arguments of node did not match params declared by definition: {0}")]
    TypeArgMismatch(#[from] TermTypeError),
    /// The type arguments are invalid for an operation.
    #[error("Invalid type arguments for operation")]
    InvalidTypeArgs,
    /// The reference to the extension defining type `typ` has been dropped.
    #[error(
        "Type '{typ}' is defined in extension '{missing}', but the extension reference has been dropped."
    )]
    MissingTypeExtension { typ: TypeName, missing: ExtensionId },
    /// Extension `exn` has no `TypeDef` named `typ`.
    #[error("Extension '{exn}' did not contain expected TypeDef '{typ}'")]
    ExtensionTypeNotFound { exn: ExtensionId, typ: TypeName },
    /// The bound recorded on a `CustomType` differs from its `TypeDef`'s bound.
    #[error("Bound on CustomType ({actual}) did not match TypeDef ({expected})")]
    WrongBound {
        actual: TypeBound,
        expected: TypeBound,
    },
    /// A type variable's cached kind differs from its declaration.
    #[error("Type Variable claims to be {cached} but actual declaration {actual}")]
    TypeVarDoesNotMatchDeclaration {
        actual: Box<TypeParam>,
        cached: Box<TypeParam>,
    },
    /// Type variable `idx` is used but only `num_decls` variables are declared.
    #[error("Type variable {idx} was not declared ({num_decls} in scope)")]
    FreeTypeVar { idx: usize, num_decls: usize },
    /// A row variable appeared where a single type was expected.
    #[error("Expected a single type, but found row variable {var}")]
    RowVarWhereTypeExpected { var: RowVariable },
    /// The signature cached on a `Call` differs from the result of applying its
    /// type arguments.
    #[error(
        "Incorrect result of type application in Call - cached {cached} but expected {expected}"
    )]
    CallIncorrectlyAppliesType {
        cached: Box<Signature>,
        expected: Box<Signature>,
    },
    /// The signature cached on a `LoadFunction` differs from the result of
    /// applying its type arguments.
    #[error(
        "Incorrect result of type application in LoadFunction - cached {cached} but expected {expected}"
    )]
    LoadFunctionIncorrectlyAppliesType {
        cached: Box<Signature>,
        expected: Box<Signature>,
    },

    /// A binary (custom) signature computation function was not loaded.
    #[error("Binary compute signature function not loaded.")]
    MissingComputeFunc,

    /// A binary (custom) signature validation function was not loaded.
    #[error("Binary validate signature function not loaded.")]
    MissingValidateFunc,
}
453
454trait CustomConcrete {
456 type Identifier;
458 fn def_name(&self) -> &Self::Identifier;
462 fn type_args(&self) -> &[TypeArg];
464 fn parent_extension(&self) -> &ExtensionId;
466}
467
468impl CustomConcrete for OpaqueOp {
469 type Identifier = OpName;
470
471 fn def_name(&self) -> &Self::Identifier {
472 self.unqualified_id()
473 }
474
475 fn type_args(&self) -> &[TypeArg] {
476 self.args()
477 }
478
479 fn parent_extension(&self) -> &ExtensionId {
480 self.extension()
481 }
482}
483
484impl CustomConcrete for CustomType {
485 type Identifier = TypeName;
486
487 fn def_name(&self) -> &TypeName {
488 self.name()
490 }
491
492 fn type_args(&self) -> &[TypeArg] {
493 self.args()
494 }
495
496 fn parent_extension(&self) -> &ExtensionId {
497 self.extension()
498 }
499}
500
/// Identifier of an extension, represented as an [`IdentList`]. Registries key
/// their extensions by this id.
pub type ExtensionId = IdentList;
505
/// A named, versioned collection of type definitions and operation definitions.
///
/// The derived `Serialize` writes the fields in declaration order.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Extension {
    /// The version of the extension.
    pub version: Version,
    /// The identifier of the extension; registries key extensions by it.
    pub name: ExtensionId,
    /// Type definitions provided by the extension, keyed by type name.
    types: BTreeMap<TypeName, TypeDef>,
    /// Operation definitions provided by the extension, keyed by operation name.
    operations: BTreeMap<OpName, Arc<op_def::OpDef>>,
}
562
563impl Extension {
564 #[must_use]
572 pub fn new(name: ExtensionId, version: Version) -> Self {
573 Self {
574 name,
575 version,
576 types: Default::default(),
577 operations: Default::default(),
578 }
579 }
580
581 pub fn new_arc(
587 name: ExtensionId,
588 version: Version,
589 init: impl FnOnce(&mut Extension, &Weak<Extension>),
590 ) -> Arc<Self> {
591 Arc::new_cyclic(|extension_ref| {
592 let mut ext = Self::new(name, version);
593 init(&mut ext, extension_ref);
594 ext
595 })
596 }
597
598 pub fn try_new_arc<E>(
605 name: ExtensionId,
606 version: Version,
607 init: impl FnOnce(&mut Extension, &Weak<Extension>) -> Result<(), E>,
608 ) -> Result<Arc<Self>, E> {
609 let mut error = None;
616 let ext = Arc::new_cyclic(|extension_ref| {
617 let mut ext = Self::new(name, version);
618 match init(&mut ext, extension_ref) {
619 Ok(()) => ext,
620 Err(e) => {
621 error = Some(e);
622 ext
623 }
624 }
625 });
626 match error {
627 Some(e) => Err(e),
628 None => Ok(ext),
629 }
630 }
631
632 #[must_use]
634 pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc<op_def::OpDef>> {
635 self.operations.get(name)
636 }
637
638 #[must_use]
640 pub fn get_type(&self, type_name: &TypeNameRef) -> Option<&type_def::TypeDef> {
641 self.types.get(type_name)
642 }
643
644 #[must_use]
646 pub fn name(&self) -> &ExtensionId {
647 &self.name
648 }
649
650 #[must_use]
652 pub fn version(&self) -> &Version {
653 &self.version
654 }
655
656 pub fn operations(&self) -> impl Iterator<Item = (&OpName, &Arc<OpDef>)> {
658 self.operations.iter()
659 }
660
661 pub fn types(&self) -> impl Iterator<Item = (&TypeName, &TypeDef)> {
663 self.types.iter()
664 }
665
666 pub fn instantiate_extension_op(
668 &self,
669 name: &OpNameRef,
670 args: impl Into<Vec<TypeArg>>,
671 ) -> Result<ExtensionOp, SignatureError> {
672 let op_def = self.get_op(name).expect("Op not found.");
673 ExtensionOp::new(op_def.clone(), args)
674 }
675
676 fn validate(&self) -> Result<(), SignatureError> {
678 for op_def in self.operations.values() {
680 op_def.validate()?;
681 }
682 Ok(())
683 }
684}
685
686impl PartialEq for Extension {
687 fn eq(&self, other: &Self) -> bool {
688 self.name == other.name && self.version == other.version
689 }
690}
691
/// An error raised while registering or validating extensions in an
/// [`ExtensionRegistry`].
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExtensionRegistryError {
    /// An extension with the same id is already registered. Fields are the id,
    /// the version already in the registry, and the version of the rejected
    /// extension.
    #[error(
        "The registry already contains an extension with id {0} and version {1}. New extension has version {2}."
    )]
    AlreadyRegistered(ExtensionId, Box<Version>, Box<Version>),
    /// The extension with the given id failed validation.
    #[error("The extension {0} contains an invalid signature, {1}.")]
    InvalidSignature(ExtensionId, #[source] SignatureError),
}
705
/// An error raised while loading an [`ExtensionRegistry`] from JSON.
#[derive(Debug, Error)]
#[non_exhaustive]
#[error("Extension registry load error")]
pub enum ExtensionRegistryLoadError {
    /// The input could not be deserialized as JSON extensions.
    #[error(transparent)]
    SerdeError(#[from] serde_json::Error),
    /// Building the registry with extension resolution failed.
    #[error(transparent)]
    ExtensionResolutionError(Box<ExtensionResolutionError>),
}
718
719impl From<ExtensionResolutionError> for ExtensionRegistryLoadError {
720 fn from(error: ExtensionResolutionError) -> Self {
721 Self::ExtensionResolutionError(Box::new(error))
722 }
723}
724
725#[derive(Debug, Clone, Error, PartialEq, Eq)]
727#[non_exhaustive]
728pub enum ExtensionBuildError {
729 #[error("Extension already has an op called {0}.")]
731 OpDefExists(OpName),
732 #[error("Extension already has an type called {0}.")]
734 TypeDefExists(TypeName),
735}
736
/// An ordered set of extension ids, displayed as `[id1, id2, ...]`.
#[derive(
    Clone, Debug, Display, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize,
)]
#[display("[{}]", _0.iter().join(", "))]
pub struct ExtensionSet(BTreeSet<ExtensionId>);
743
744impl ExtensionSet {
745 #[must_use]
747 pub const fn new() -> Self {
748 Self(BTreeSet::new())
749 }
750
751 pub fn insert(&mut self, extension: ExtensionId) {
753 self.0.insert(extension.clone());
754 }
755
756 #[must_use]
758 pub fn contains(&self, extension: &ExtensionId) -> bool {
759 self.0.contains(extension)
760 }
761
762 #[must_use]
764 pub fn is_subset(&self, other: &Self) -> bool {
765 self.0.is_subset(&other.0)
766 }
767
768 #[must_use]
770 pub fn is_superset(&self, other: &Self) -> bool {
771 self.0.is_superset(&other.0)
772 }
773
774 #[must_use]
776 pub fn singleton(extension: ExtensionId) -> Self {
777 let mut set = Self::new();
778 set.insert(extension);
779 set
780 }
781
782 #[must_use]
784 pub fn union(mut self, other: Self) -> Self {
785 self.0.extend(other.0);
786 self
787 }
788
789 pub fn union_over(sets: impl IntoIterator<Item = Self>) -> Self {
791 let mut res = ExtensionSet::new();
793 for s in sets {
794 res.0.extend(s.0);
795 }
796 res
797 }
798
799 #[must_use]
801 pub fn missing_from(&self, other: &Self) -> Self {
802 ExtensionSet::from_iter(other.0.difference(&self.0).cloned())
803 }
804
805 pub fn iter(&self) -> impl Iterator<Item = &ExtensionId> {
807 self.0.iter()
808 }
809
810 #[must_use]
812 pub fn is_empty(&self) -> bool {
813 self.0.is_empty()
814 }
815}
816
817impl From<ExtensionId> for ExtensionSet {
818 fn from(id: ExtensionId) -> Self {
819 Self::singleton(id)
820 }
821}
822
823impl IntoIterator for ExtensionSet {
824 type Item = ExtensionId;
825 type IntoIter = std::collections::btree_set::IntoIter<ExtensionId>;
826
827 fn into_iter(self) -> Self::IntoIter {
828 self.0.into_iter()
829 }
830}
831
832impl<'a> IntoIterator for &'a ExtensionSet {
833 type Item = &'a ExtensionId;
834 type IntoIter = std::collections::btree_set::Iter<'a, ExtensionId>;
835
836 fn into_iter(self) -> Self::IntoIter {
837 self.0.iter()
838 }
839}
840
841impl FromIterator<ExtensionId> for ExtensionSet {
842 fn from_iter<I: IntoIterator<Item = ExtensionId>>(iter: I) -> Self {
843 Self(BTreeSet::from_iter(iter))
844 }
845}
846
#[cfg(test)]
pub mod test {
    // Re-exported here because `mod op_def` is private.
    pub use super::op_def::test::SimpleOpDef;

    use super::*;

    impl Extension {
        /// Creates an extension for testing, with version `0.0.0`.
        pub(crate) fn new_test_arc(
            name: ExtensionId,
            init: impl FnOnce(&mut Extension, &Weak<Extension>),
        ) -> Arc<Self> {
            Self::new_arc(name, Version::new(0, 0, 0), init)
        }

        /// Fallible variant of [`Extension::new_test_arc`], with version `0.0.0`.
        pub(crate) fn try_new_test_arc(
            name: ExtensionId,
            init: impl FnOnce(
                &mut Extension,
                &Weak<Extension>,
            ) -> Result<(), Box<dyn std::error::Error>>,
        ) -> Result<Arc<Self>, Box<dyn std::error::Error>> {
            Self::try_new_arc(name, Version::new(0, 0, 0), init)
        }
    }

    #[test]
    fn test_register_update() {
        // Two registries that must stay equal: one is updated through
        // `register_updated`, the other through `register_updated_ref`.
        let mut reg = ExtensionRegistry::default();
        let mut reg_ref = ExtensionRegistry::default();

        let ext_1_id = ExtensionId::new("ext1").unwrap();
        let ext_2_id = ExtensionId::new("ext2").unwrap();
        let ext1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 0, 0)));
        let ext1_1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 1, 0)));
        let ext1_2 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(0, 2, 0)));
        let ext2 = Arc::new(Extension::new(ext_2_id, Version::new(1, 0, 0)));

        reg.register(ext1.clone()).unwrap();
        reg_ref.register(ext1.clone()).unwrap();
        assert_eq!(&reg, &reg_ref);

        // `register` refuses to overwrite an existing id, even with a newer version.
        assert_eq!(
            reg.register(ext1_1.clone()),
            Err(ExtensionRegistryError::AlreadyRegistered(
                ext_1_id.clone(),
                Box::new(Version::new(1, 0, 0)),
                Box::new(Version::new(1, 1, 0))
            ))
        );

        // A newer version replaces the registered one.
        reg_ref.register_updated_ref(&ext1_1);
        reg.register_updated(ext1_1.clone());
        assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
        assert_eq!(&reg, &reg_ref);

        // An older version is ignored.
        reg_ref.register_updated_ref(&ext1_2);
        reg.register_updated(ext1_2.clone());
        assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
        assert_eq!(&reg, &reg_ref);

        reg.register(ext2.clone()).unwrap();
        assert_eq!(reg.get("ext2").unwrap().version(), &Version::new(1, 0, 0));
        assert_eq!(reg.len(), 2);

        assert_eq!(
            reg.remove_extension(&ext_1_id).unwrap().version(),
            &Version::new(1, 1, 0)
        );
        assert_eq!(reg.len(), 1);
    }

    mod proptest {

        use ::proptest::{collection::hash_set, prelude::*};

        use super::super::{ExtensionId, ExtensionSet};

        impl Arbitrary for ExtensionSet {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;

            fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
                hash_set(any::<ExtensionId>(), 0..3)
                    .prop_map(|extensions| extensions.into_iter().collect::<ExtensionSet>())
                    .boxed()
            }
        }
    }
}