hugr_core/
extension.rs

1//! Extensions
2//!
3//! TODO: YAML declaration and parsing. This should be similar to a plugin
4//! system (outside the `types` module), which also parses nested [`OpDef`]s.
5
6use itertools::Itertools;
7use resolution::{ExtensionResolutionError, WeakExtensionRegistry};
8pub use semver::Version;
9use serde::{Deserialize, Deserializer, Serialize};
10use std::cell::UnsafeCell;
11use std::collections::btree_map;
12use std::collections::{BTreeMap, BTreeSet};
13use std::fmt::Debug;
14use std::sync::atomic::{AtomicBool, Ordering};
15use std::sync::{Arc, Weak};
16use std::{io, mem};
17
18use derive_more::Display;
19use thiserror::Error;
20
21use crate::hugr::IdentList;
22use crate::ops::custom::{ExtensionOp, OpaqueOp};
23use crate::ops::{OpName, OpNameRef};
24use crate::types::RowVariable;
25use crate::types::type_param::{TermTypeError, TypeArg, TypeParam};
26use crate::types::{CustomType, TypeBound, TypeName};
27use crate::types::{Signature, TypeNameRef};
28
29mod const_fold;
30mod op_def;
31pub mod prelude;
32pub mod resolution;
33pub mod simple_op;
34mod type_def;
35
36pub use const_fold::{ConstFold, ConstFoldResult, Folder, fold_out_row};
37pub use op_def::{
38    CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc,
39    ValidateJustArgs, ValidateTypeArgs, deserialize_lower_funcs,
40};
41pub use prelude::{PRELUDE, PRELUDE_REGISTRY};
42pub use type_def::{TypeDef, TypeDefBound};
43
44#[cfg(feature = "declarative")]
45pub mod declarative;
46
/// Extension Registries store extensions to be looked up e.g. during validation.
///
/// Extensions are keyed by their [`ExtensionId`], so a registry holds at most
/// one [`Extension`] per id.
#[derive(Debug, Display, Default)]
#[display("ExtensionRegistry[{}]", exts.keys().join(", "))]
pub struct ExtensionRegistry {
    /// The extensions in the registry.
    exts: BTreeMap<ExtensionId, Arc<Extension>>,
    /// A flag indicating whether the current set of extensions has been
    /// validated.
    ///
    /// This is used to avoid re-validating the extensions every time the
    /// registry is validated, and is set to `false` whenever an extension is
    /// added or removed.
    valid: AtomicBool,
}
61
62impl PartialEq for ExtensionRegistry {
63    fn eq(&self, other: &Self) -> bool {
64        self.exts == other.exts
65    }
66}
67
68impl Clone for ExtensionRegistry {
69    fn clone(&self) -> Self {
70        Self {
71            exts: self.exts.clone(),
72            valid: self.valid.load(Ordering::Relaxed).into(),
73        }
74    }
75}
76
77impl ExtensionRegistry {
78    /// Create a new empty extension registry.
79    pub fn new(extensions: impl IntoIterator<Item = Arc<Extension>>) -> Self {
80        let mut res = Self::default();
81        for ext in extensions {
82            res.register_updated(ext);
83        }
84        res
85    }
86
87    /// Load an `ExtensionRegistry` serialized as json.
88    ///
89    /// After deserialization, updates all the internal `Weak<Extension>`
90    /// references to point to the newly created [`Arc`]s in the registry,
91    /// or extensions in the `additional_extensions` parameter.
92    pub fn load_json(
93        reader: impl io::Read,
94        other_extensions: &ExtensionRegistry,
95    ) -> Result<Self, ExtensionRegistryLoadError> {
96        let extensions: Vec<Extension> = serde_json::from_reader(reader)?;
97        // After deserialization, we need to update all the internal
98        // `Weak<Extension>` references.
99        Ok(ExtensionRegistry::new_with_extension_resolution(
100            extensions,
101            &other_extensions.into(),
102        )?)
103    }
104
    /// Gets the [`Extension`] registered under the given name, or `None` if
    /// the registry holds no extension with that name.
    pub fn get(&self, name: &str) -> Option<&Arc<Extension>> {
        self.exts.get(name)
    }
109
    /// Returns `true` if the registry contains an extension with the given name.
    ///
    /// Only the name is checked; the version of the registered extension is
    /// not considered.
    pub fn contains(&self, name: &str) -> bool {
        self.exts.contains_key(name)
    }
114
115    /// Validate the set of extensions.
116    pub fn validate(&self) -> Result<(), ExtensionRegistryError> {
117        if self.valid.load(Ordering::Relaxed) {
118            return Ok(());
119        }
120        for ext in self.exts.values() {
121            ext.validate()
122                .map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?;
123        }
124        self.valid.store(true, Ordering::Relaxed);
125        Ok(())
126    }
127
128    /// Registers a new extension to the registry.
129    ///
130    /// Returns a reference to the registered extension if successful.
131    pub fn register(
132        &mut self,
133        extension: impl Into<Arc<Extension>>,
134    ) -> Result<(), ExtensionRegistryError> {
135        let extension = extension.into();
136        match self.exts.entry(extension.name().clone()) {
137            btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered(
138                extension.name().clone(),
139                Box::new(prev.get().version().clone()),
140                Box::new(extension.version().clone()),
141            )),
142            btree_map::Entry::Vacant(ve) => {
143                ve.insert(extension);
144                // Clear the valid flag so that the registry is re-validated.
145                self.valid.store(false, Ordering::Relaxed);
146
147                Ok(())
148            }
149        }
150    }
151
152    /// Registers a new extension to the registry, keeping the one most up to
153    /// date if the extension already exists.
154    ///
155    /// If extension IDs match, the extension with the higher version is kept.
156    /// If versions match, the original extension is kept. Returns a reference
157    /// to the registered extension if successful.
158    ///
159    /// Takes an Arc to the extension. To avoid cloning Arcs unless necessary,
160    /// see [`ExtensionRegistry::register_updated_ref`].
161    pub fn register_updated(&mut self, extension: impl Into<Arc<Extension>>) {
162        let extension = extension.into();
163        match self.exts.entry(extension.name().clone()) {
164            btree_map::Entry::Occupied(mut prev) => {
165                if prev.get().version() < extension.version() {
166                    *prev.get_mut() = extension;
167                }
168            }
169            btree_map::Entry::Vacant(ve) => {
170                ve.insert(extension);
171            }
172        }
173        // Clear the valid flag so that the registry is re-validated.
174        self.valid.store(false, Ordering::Relaxed);
175    }
176
177    /// Registers a new extension to the registry, keeping the one most up to
178    /// date if the extension already exists.
179    ///
180    /// If extension IDs match, the extension with the higher version is kept.
181    /// If versions match, the original extension is kept. Returns a reference
182    /// to the registered extension if successful.
183    ///
184    /// Clones the Arc only when required. For no-cloning version see
185    /// [`ExtensionRegistry::register_updated`].
186    pub fn register_updated_ref(&mut self, extension: &Arc<Extension>) {
187        match self.exts.entry(extension.name().clone()) {
188            btree_map::Entry::Occupied(mut prev) => {
189                if prev.get().version() < extension.version() {
190                    *prev.get_mut() = extension.clone();
191                }
192            }
193            btree_map::Entry::Vacant(ve) => {
194                ve.insert(extension.clone());
195            }
196        }
197        // Clear the valid flag so that the registry is re-validated.
198        self.valid.store(false, Ordering::Relaxed);
199    }
200
    /// Returns the number of extensions in the registry.
    ///
    /// Extensions are stored by id, so each id is counted once.
    pub fn len(&self) -> usize {
        self.exts.len()
    }
205
    /// Returns `true` if the registry contains no extensions, i.e. if
    /// [`ExtensionRegistry::len`] is zero.
    pub fn is_empty(&self) -> bool {
        self.exts.is_empty()
    }
210
    /// Returns an iterator over the extensions in the registry.
    ///
    /// The extensions are yielded in the order of their ids, since they are
    /// stored in a `BTreeMap` keyed by [`ExtensionId`].
    pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
        self.exts.values()
    }
215
    /// Returns an iterator over the ids of the extensions in the registry,
    /// in sorted order.
    pub fn ids(&self) -> impl Iterator<Item = &ExtensionId> {
        self.exts.keys()
    }
220
221    /// Delete an extension from the registry and return it if it was present.
222    pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Arc<Extension>> {
223        // Clear the valid flag so that the registry is re-validated.
224        self.valid.store(false, Ordering::Relaxed);
225
226        self.exts.remove(name)
227    }
228
    /// Constructs a new `ExtensionRegistry` from a list of [`Extension`]s while
    /// giving you a [`WeakExtensionRegistry`] to the allocation. This allows
    /// you to add [`Weak`] self-references to the [`Extension`]s while
    /// constructing them, before wrapping them in [`Arc`]s.
    ///
    /// This is similar to [`Arc::new_cyclic`], but for `ExtensionRegistries`.
    ///
    /// Calling [`Weak::upgrade`] on a weak reference in the
    /// [`WeakExtensionRegistry`] inside your closure will return an extension
    /// with no internal (op / type / value) definitions.
    ///
    /// The extensions returned by `init` are paired *by position* with the
    /// allocations created for the input extensions, so `init` should return
    /// the same extensions in the same order it received them. If it returns
    /// fewer or more, the surplus on either side is silently dropped.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `init`, if any. No registry is built in
    /// that case.
    //
    // It may be possible to implement this safely using `Arc::new_cyclic`
    // directly, but the callback type does not allow for returning extra
    // data so it seems unlikely.
    pub fn new_cyclic<F, E>(
        extensions: impl IntoIterator<Item = Extension>,
        init: F,
    ) -> Result<Self, E>
    where
        F: FnOnce(Vec<Extension>, &WeakExtensionRegistry) -> Result<Vec<Extension>, E>,
    {
        let extensions = extensions.into_iter().collect_vec();

        // Unsafe internally-mutable wrapper around an extension. Important:
        // `repr(transparent)` ensures the layout is identical to `Extension`,
        // so it can be safely transmuted.
        #[repr(transparent)]
        struct ExtensionCell {
            ext: UnsafeCell<Extension>,
        }

        // Create the arcs with internal mutability, and collect weak references
        // over immutable references.
        //
        // This is safe as long as the cell mutation happens when we can guarantee
        // that the weak references are not used.
        let (arcs, weaks): (Vec<Arc<ExtensionCell>>, Vec<Weak<Extension>>) = extensions
            .iter()
            .map(|ext| {
                // Create a new arc with an empty extension sharing the name and version of the original,
                // but with no internal definitions.
                //
                // `UnsafeCell` is not sync, but we are not writing to it while the weak references are
                // being used.
                #[allow(clippy::arc_with_non_send_sync)]
                let arc = Arc::new(ExtensionCell {
                    ext: UnsafeCell::new(Extension::new(ext.name().clone(), ext.version().clone())),
                });

                // SAFETY: `ExtensionCell` is `repr(transparent)`, so it has the same layout as `Extension`.
                let weak_arc: Weak<Extension> = unsafe { mem::transmute(Arc::downgrade(&arc)) };
                (arc, weak_arc)
            })
            .unzip();

        // Expose the weak references under the name of the extension each one
        // was created for.
        let mut weak_registry = WeakExtensionRegistry::default();
        for (ext, weak) in extensions.iter().zip(weaks) {
            weak_registry.register(ext.name().clone(), weak);
        }

        // Actual initialization here
        // Upgrading the weak references at any point here will access the empty extensions in the arcs.
        let extensions = init(extensions, &weak_registry)?;

        // We're done.
        // The i-th extension returned by `init` is written into the i-th
        // pre-allocated arc.
        let arcs: Vec<Arc<Extension>> = arcs
            .into_iter()
            .zip(extensions)
            .map(|(arc, ext)| {
                // Replace the dummy extensions with the updated ones.
                // SAFETY: The cell is only mutated when the weak references are not used.
                // NOTE(review): nothing visible here prevents `init` from keeping an
                // upgraded `Arc<Extension>` alive past its return; this write assumes
                // it did not — confirm that callers uphold this.
                unsafe { *arc.ext.get() = ext };
                // Pretend the UnsafeCells never existed.
                // SAFETY: `ExtensionCell` is `repr(transparent)`, so it has the same layout as `Extension`.
                unsafe { mem::transmute::<Arc<ExtensionCell>, Arc<Extension>>(arc) }
            })
            .collect();
        Ok(ExtensionRegistry::new(arcs))
    }
308}
309
310impl IntoIterator for ExtensionRegistry {
311    type Item = Arc<Extension>;
312
313    type IntoIter = std::collections::btree_map::IntoValues<ExtensionId, Arc<Extension>>;
314
315    fn into_iter(self) -> Self::IntoIter {
316        self.exts.into_values()
317    }
318}
319
320impl<'a> IntoIterator for &'a ExtensionRegistry {
321    type Item = &'a Arc<Extension>;
322
323    type IntoIter = std::collections::btree_map::Values<'a, ExtensionId, Arc<Extension>>;
324
325    fn into_iter(self) -> Self::IntoIter {
326        self.exts.values()
327    }
328}
329
330impl<'a> Extend<&'a Arc<Extension>> for ExtensionRegistry {
331    fn extend<T: IntoIterator<Item = &'a Arc<Extension>>>(&mut self, iter: T) {
332        for ext in iter {
333            self.register_updated_ref(ext);
334        }
335    }
336}
337
338impl Extend<Arc<Extension>> for ExtensionRegistry {
339    fn extend<T: IntoIterator<Item = Arc<Extension>>>(&mut self, iter: T) {
340        for ext in iter {
341            self.register_updated(ext);
342        }
343    }
344}
345
346/// Encode/decode `ExtensionRegistry` as a list of extensions.
347///
348/// Any `Weak<Extension>` references inside the registry will be left unresolved.
349/// Prefer using [`ExtensionRegistry::load_json`] when deserializing.
350impl<'de> Deserialize<'de> for ExtensionRegistry {
351    fn deserialize<D>(deserializer: D) -> Result<ExtensionRegistry, D::Error>
352    where
353        D: Deserializer<'de>,
354    {
355        let extensions: Vec<Arc<Extension>> = Vec::deserialize(deserializer)?;
356        Ok(ExtensionRegistry::new(extensions))
357    }
358}
359
360impl Serialize for ExtensionRegistry {
361    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
362    where
363        S: serde::Serializer,
364    {
365        let extensions: Vec<Arc<Extension>> = self.exts.values().cloned().collect();
366        extensions.serialize(serializer)
367    }
368}
369
/// An Extension Registry containing no extensions.
///
/// It is created already marked as valid, since there is nothing in it to
/// validate.
pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry {
    exts: BTreeMap::new(),
    valid: AtomicBool::new(true),
};
375
/// An error that can occur in computing the signature of a node.
/// TODO: decide on failure modes
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum SignatureError {
    /// Name mismatch
    #[error("Definition name ({0}) and instantiation name ({1}) do not match.")]
    NameMismatch(TypeName, TypeName),
    /// Extension mismatch
    #[error("Definition extension ({0}) and instantiation extension ({1}) do not match.")]
    ExtensionMismatch(ExtensionId, ExtensionId),
    /// When the type arguments of the node did not match the params declared by the `OpDef`
    #[error("Type arguments of node did not match params declared by definition: {0}")]
    TypeArgMismatch(#[from] TermTypeError),
    /// Invalid type arguments
    #[error("Invalid type arguments for operation")]
    InvalidTypeArgs,
    /// The weak [`Extension`] reference for a custom type has been dropped.
    #[error(
        "Type '{typ}' is defined in extension '{missing}', but the extension reference has been dropped."
    )]
    MissingTypeExtension { typ: TypeName, missing: ExtensionId },
    /// The Extension was found in the registry, but did not contain the Type(Def) referenced in the Signature
    #[error("Extension '{exn}' did not contain expected TypeDef '{typ}'")]
    ExtensionTypeNotFound { exn: ExtensionId, typ: TypeName },
    /// The bound recorded for a `CustomType` doesn't match what the `TypeDef` would compute
    #[error("Bound on CustomType ({actual}) did not match TypeDef ({expected})")]
    WrongBound {
        actual: TypeBound,
        expected: TypeBound,
    },
    /// A Type Variable's cache of its declared kind is incorrect
    #[error("Type Variable claims to be {cached} but actual declaration {actual}")]
    TypeVarDoesNotMatchDeclaration {
        actual: Box<TypeParam>,
        cached: Box<TypeParam>,
    },
    /// A type variable that was used has not been declared
    #[error("Type variable {idx} was not declared ({num_decls} in scope)")]
    FreeTypeVar { idx: usize, num_decls: usize },
    /// A row variable was found outside of a variable-length row
    #[error("Expected a single type, but found row variable {var}")]
    RowVarWhereTypeExpected { var: RowVariable },
    /// The result of the type application stored in a [Call]
    /// is not what we get by applying the type-args to the polymorphic function
    ///
    /// [Call]: crate::ops::dataflow::Call
    #[error(
        "Incorrect result of type application in Call - cached {cached} but expected {expected}"
    )]
    CallIncorrectlyAppliesType {
        cached: Box<Signature>,
        expected: Box<Signature>,
    },
    /// The result of the type application stored in a [`LoadFunction`]
    /// is not what we get by applying the type-args to the polymorphic function
    ///
    /// [`LoadFunction`]: crate::ops::dataflow::LoadFunction
    #[error(
        "Incorrect result of type application in LoadFunction - cached {cached} but expected {expected}"
    )]
    LoadFunctionIncorrectlyAppliesType {
        cached: Box<Signature>,
        expected: Box<Signature>,
    },

    /// Extension declaration specifies a binary compute signature function, but none
    /// was loaded.
    #[error("Binary compute signature function not loaded.")]
    MissingComputeFunc,

    /// Extension declaration specifies a binary validate signature function, but none
    /// was loaded.
    #[error("Binary validate signature function not loaded.")]
    MissingValidateFunc,
}
453
/// Concrete instantiations of types and operations defined in extensions.
///
/// Gives uniform access to the definition name, the type arguments and the
/// owning extension of an instantiation.
trait CustomConcrete {
    /// The identifier type for the concrete object.
    type Identifier;
    /// A generic identifier to the element.
    ///
    /// This may either refer to a [`TypeName`] or an [`OpName`].
    fn def_name(&self) -> &Self::Identifier;
    /// The concrete type arguments for the instantiation.
    fn type_args(&self) -> &[TypeArg];
    /// Extension required by the instantiation.
    fn parent_extension(&self) -> &ExtensionId;
}
467
/// An [`OpaqueOp`] is identified by its unqualified [`OpName`].
impl CustomConcrete for OpaqueOp {
    type Identifier = OpName;

    fn def_name(&self) -> &Self::Identifier {
        // The id without the extension prefix; the extension itself is
        // reported by `parent_extension`.
        self.unqualified_id()
    }

    fn type_args(&self) -> &[TypeArg] {
        self.args()
    }

    fn parent_extension(&self) -> &ExtensionId {
        self.extension()
    }
}
483
/// A [`CustomType`] is identified by its [`TypeName`].
impl CustomConcrete for CustomType {
    type Identifier = TypeName;

    fn def_name(&self) -> &TypeName {
        // The type's name is returned as-is; no conversion takes place.
        self.name()
    }

    fn type_args(&self) -> &[TypeArg] {
        self.args()
    }

    fn parent_extension(&self) -> &ExtensionId {
        self.extension()
    }
}
500
/// A unique identifier for an extension.
///
/// The actual [`Extension`] is stored externally.
pub type ExtensionId = IdentList;
505
/// An extension is a set of capabilities required to execute a graph.
///
/// These are normally defined once and shared across multiple graphs and
/// operations wrapped in [`Arc`]s inside [`ExtensionRegistry`].
///
/// # Example
///
/// The following example demonstrates how to define a new extension with a
/// custom operation and a custom type.
///
/// When using `arc`s, the extension can only be modified at creation time. The
/// defined operations and types keep a [`Weak`] reference to their extension. We provide a
/// helper method [`Extension::new_arc`] to aid their definition.
///
/// ```
/// # use hugr_core::types::Signature;
/// # use hugr_core::extension::{Extension, ExtensionId, Version};
/// # use hugr_core::extension::{TypeDefBound};
/// Extension::new_arc(
///     ExtensionId::new_unchecked("my.extension"),
///     Version::new(0, 1, 0),
///     |ext, extension_ref| {
///         // Add a custom type definition
///         ext.add_type(
///             "MyType".into(),
///             vec![], // No type parameters
///             "Some type".into(),
///             TypeDefBound::any(),
///             extension_ref,
///         );
///         // Add a custom operation
///         ext.add_op(
///             "MyOp".into(),
///             "Some operation".into(),
///             Signature::new_endo(vec![]),
///             extension_ref,
///         );
///     },
/// );
/// ```
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Extension {
    /// Extension version, follows semver.
    pub version: Version,
    /// Unique identifier for the extension.
    pub name: ExtensionId,
    /// Types defined by this extension, keyed by type name.
    types: BTreeMap<TypeName, TypeDef>,
    /// Operation declarations with serializable definitions, keyed by
    /// operation name.
    // Note: serde will serialize this because we configure with `features=["rc"]`.
    // That will clone anything that has multiple references, but each
    // OpDef should appear exactly once in this map (keyed by its name),
    // and the other references to the OpDef are from ExternalOp's in the Hugr
    // (which are serialized as OpaqueOp's i.e. Strings).
    operations: BTreeMap<OpName, Arc<op_def::OpDef>>,
}
562
563impl Extension {
564    /// Creates a new extension with the given name.
565    ///
566    /// In most cases extensions are contained inside an [`Arc`] so that they
567    /// can be shared across hugr instances and operation definitions.
568    ///
569    /// See [`Extension::new_arc`] for a more ergonomic way to create boxed
570    /// extensions.
571    #[must_use]
572    pub fn new(name: ExtensionId, version: Version) -> Self {
573        Self {
574            name,
575            version,
576            types: Default::default(),
577            operations: Default::default(),
578        }
579    }
580
581    /// Creates a new extension wrapped in an [`Arc`].
582    ///
583    /// The closure lets us use a weak reference to the arc while the extension
584    /// is being built. This is necessary for calling [`Extension::add_op`] and
585    /// [`Extension::add_type`].
586    pub fn new_arc(
587        name: ExtensionId,
588        version: Version,
589        init: impl FnOnce(&mut Extension, &Weak<Extension>),
590    ) -> Arc<Self> {
591        Arc::new_cyclic(|extension_ref| {
592            let mut ext = Self::new(name, version);
593            init(&mut ext, extension_ref);
594            ext
595        })
596    }
597
598    /// Creates a new extension wrapped in an [`Arc`], using a fallible
599    /// initialization function.
600    ///
601    /// The closure lets us use a weak reference to the arc while the extension
602    /// is being built. This is necessary for calling [`Extension::add_op`] and
603    /// [`Extension::add_type`].
604    pub fn try_new_arc<E>(
605        name: ExtensionId,
606        version: Version,
607        init: impl FnOnce(&mut Extension, &Weak<Extension>) -> Result<(), E>,
608    ) -> Result<Arc<Self>, E> {
609        // Annoying hack around not having `Arc::try_new_cyclic` that can return
610        // a Result.
611        // https://github.com/rust-lang/rust/issues/75861#issuecomment-980455381
612        //
613        // When there is an error, we store it in `error` and return it at the
614        // end instead of the partially-initialized extension.
615        let mut error = None;
616        let ext = Arc::new_cyclic(|extension_ref| {
617            let mut ext = Self::new(name, version);
618            match init(&mut ext, extension_ref) {
619                Ok(()) => ext,
620                Err(e) => {
621                    error = Some(e);
622                    ext
623                }
624            }
625        });
626        match error {
627            Some(e) => Err(e),
628            None => Ok(ext),
629        }
630    }
631
632    /// Allows read-only access to the operations in this Extension
633    #[must_use]
634    pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc<op_def::OpDef>> {
635        self.operations.get(name)
636    }
637
638    /// Allows read-only access to the types in this Extension
639    #[must_use]
640    pub fn get_type(&self, type_name: &TypeNameRef) -> Option<&type_def::TypeDef> {
641        self.types.get(type_name)
642    }
643
644    /// Returns the name of the extension.
645    #[must_use]
646    pub fn name(&self) -> &ExtensionId {
647        &self.name
648    }
649
650    /// Returns the version of the extension.
651    #[must_use]
652    pub fn version(&self) -> &Version {
653        &self.version
654    }
655
656    /// Iterator over the operations of this [`Extension`].
657    pub fn operations(&self) -> impl Iterator<Item = (&OpName, &Arc<OpDef>)> {
658        self.operations.iter()
659    }
660
661    /// Iterator over the types of this [`Extension`].
662    pub fn types(&self) -> impl Iterator<Item = (&TypeName, &TypeDef)> {
663        self.types.iter()
664    }
665
666    /// Instantiate an [`ExtensionOp`] which references an [`OpDef`] in this extension.
667    pub fn instantiate_extension_op(
668        &self,
669        name: &OpNameRef,
670        args: impl Into<Vec<TypeArg>>,
671    ) -> Result<ExtensionOp, SignatureError> {
672        let op_def = self.get_op(name).expect("Op not found.");
673        ExtensionOp::new(op_def.clone(), args)
674    }
675
676    /// Validates the operation definitions in the register.
677    fn validate(&self) -> Result<(), SignatureError> {
678        // We should validate TypeParams of TypeDefs too - https://github.com/CQCL/hugr/issues/624
679        for op_def in self.operations.values() {
680            op_def.validate()?;
681        }
682        Ok(())
683    }
684}
685
686impl PartialEq for Extension {
687    fn eq(&self, other: &Self) -> bool {
688        self.name == other.name && self.version == other.version
689    }
690}
691
/// An error that can occur in defining an extension registry.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExtensionRegistryError {
    /// Extension already defined.
    ///
    /// Carries the extension id, the version already in the registry, and
    /// the version of the extension that was being registered, in that order.
    #[error(
        "The registry already contains an extension with id {0} and version {1}. New extension has version {2}."
    )]
    AlreadyRegistered(ExtensionId, Box<Version>, Box<Version>),
    /// A registered extension has invalid signatures.
    ///
    /// Carries the id of the offending extension and the underlying error.
    #[error("The extension {0} contains an invalid signature, {1}.")]
    InvalidSignature(ExtensionId, #[source] SignatureError),
}
705
/// An error that can occur while loading an extension registry.
///
/// Both variants are transparent wrappers, so their messages are those of
/// the wrapped errors.
#[derive(Debug, Error)]
#[non_exhaustive]
#[error("Extension registry load error")]
pub enum ExtensionRegistryLoadError {
    /// Deserialization error.
    #[error(transparent)]
    SerdeError(#[from] serde_json::Error),
    /// Error when resolving internal extension references.
    ///
    /// The inner error is stored in a `Box`.
    #[error(transparent)]
    ExtensionResolutionError(Box<ExtensionResolutionError>),
}
718
/// Boxes the resolution error so it can be propagated with `?` from
/// functions returning [`ExtensionRegistryLoadError`].
impl From<ExtensionResolutionError> for ExtensionRegistryLoadError {
    fn from(error: ExtensionResolutionError) -> Self {
        Self::ExtensionResolutionError(Box::new(error))
    }
}
724
/// An error that can occur in building a new extension.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExtensionBuildError {
    /// The extension already contains an [`OpDef`] with this name.
    #[error("Extension already has an op called {0}.")]
    OpDefExists(OpName),
    /// The extension already contains a [`TypeDef`] with this name.
    // NOTE(review): the message below reads "an type"; correcting it would
    // change user-visible text, so confirm nothing matches on it first.
    #[error("Extension already has an type called {0}.")]
    TypeDefExists(TypeName),
}
736
/// A set of extensions identified by their unique [`ExtensionId`].
///
/// Only the ids are stored, in sorted order; the set is displayed as a
/// comma-separated list of ids in square brackets.
#[derive(
    Clone, Debug, Display, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize,
)]
#[display("[{}]", _0.iter().join(", "))]
pub struct ExtensionSet(BTreeSet<ExtensionId>);
743
744impl ExtensionSet {
745    /// Creates a new empty extension set.
746    #[must_use]
747    pub const fn new() -> Self {
748        Self(BTreeSet::new())
749    }
750
751    /// Adds a extension to the set.
752    pub fn insert(&mut self, extension: ExtensionId) {
753        self.0.insert(extension.clone());
754    }
755
756    /// Returns `true` if the set contains the given extension.
757    #[must_use]
758    pub fn contains(&self, extension: &ExtensionId) -> bool {
759        self.0.contains(extension)
760    }
761
762    /// Returns `true` if the set is a subset of `other`.
763    #[must_use]
764    pub fn is_subset(&self, other: &Self) -> bool {
765        self.0.is_subset(&other.0)
766    }
767
768    /// Returns `true` if the set is a superset of `other`.
769    #[must_use]
770    pub fn is_superset(&self, other: &Self) -> bool {
771        self.0.is_superset(&other.0)
772    }
773
774    /// Create a extension set with a single element.
775    #[must_use]
776    pub fn singleton(extension: ExtensionId) -> Self {
777        let mut set = Self::new();
778        set.insert(extension);
779        set
780    }
781
782    /// Returns the union of two extension sets.
783    #[must_use]
784    pub fn union(mut self, other: Self) -> Self {
785        self.0.extend(other.0);
786        self
787    }
788
789    /// Returns the union of an arbitrary collection of [`ExtensionSet`]s
790    pub fn union_over(sets: impl IntoIterator<Item = Self>) -> Self {
791        // `union` clones the receiver, which we do not need to do here
792        let mut res = ExtensionSet::new();
793        for s in sets {
794            res.0.extend(s.0);
795        }
796        res
797    }
798
799    /// The things in other which are in not in self
800    #[must_use]
801    pub fn missing_from(&self, other: &Self) -> Self {
802        ExtensionSet::from_iter(other.0.difference(&self.0).cloned())
803    }
804
805    /// Iterate over the contained `ExtensionIds`
806    pub fn iter(&self) -> impl Iterator<Item = &ExtensionId> {
807        self.0.iter()
808    }
809
810    /// True if this set contains no [`ExtensionId`]s
811    #[must_use]
812    pub fn is_empty(&self) -> bool {
813        self.0.is_empty()
814    }
815}
816
/// Converts a single id into the set containing just that id.
impl From<ExtensionId> for ExtensionSet {
    fn from(id: ExtensionId) -> Self {
        Self::singleton(id)
    }
}
822
/// Consumes the set, yielding its ids in sorted order.
impl IntoIterator for ExtensionSet {
    type Item = ExtensionId;
    type IntoIter = std::collections::btree_set::IntoIter<ExtensionId>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}
831
/// Borrows the set, yielding references to its ids in sorted order.
impl<'a> IntoIterator for &'a ExtensionSet {
    type Item = &'a ExtensionId;
    type IntoIter = std::collections::btree_set::Iter<'a, ExtensionId>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.iter()
    }
}
840
/// Collects ids into a set; repeated ids are stored once.
impl FromIterator<ExtensionId> for ExtensionSet {
    fn from_iter<I: IntoIterator<Item = ExtensionId>>(iter: I) -> Self {
        Self(BTreeSet::from_iter(iter))
    }
}
846
847/// Extension tests.
848#[cfg(test)]
849pub mod test {
850    // We re-export this here because mod op_def is private.
851    pub use super::op_def::test::SimpleOpDef;
852
853    use super::*;
854
    /// Test-only constructors that fix the version to `0.0.0`.
    impl Extension {
        /// Create a new extension for testing, with a 0 version.
        pub(crate) fn new_test_arc(
            name: ExtensionId,
            init: impl FnOnce(&mut Extension, &Weak<Extension>),
        ) -> Arc<Self> {
            Self::new_arc(name, Version::new(0, 0, 0), init)
        }

        /// Create a new extension for testing, with a 0 version, using a
        /// fallible initialization function.
        ///
        /// Any error returned by `init` is passed back to the caller.
        pub(crate) fn try_new_test_arc(
            name: ExtensionId,
            init: impl FnOnce(
                &mut Extension,
                &Weak<Extension>,
            ) -> Result<(), Box<dyn std::error::Error>>,
        ) -> Result<Arc<Self>, Box<dyn std::error::Error>> {
            Self::try_new_arc(name, Version::new(0, 0, 0), init)
        }
    }
875
    /// Checks `register`, `register_updated`, `register_updated_ref` and
    /// `remove_extension` against extensions sharing an id but differing in
    /// version.
    #[test]
    fn test_register_update() {
        // Two registries that should remain the same.
        // We use them to test both `register_updated` and `register_updated_ref`.
        let mut reg = ExtensionRegistry::default();
        let mut reg_ref = ExtensionRegistry::default();

        // Three versions of "ext1" (1.0.0, a newer 1.1.0 and an older 0.2.0)
        // plus an unrelated "ext2".
        let ext_1_id = ExtensionId::new("ext1").unwrap();
        let ext_2_id = ExtensionId::new("ext2").unwrap();
        let ext1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 0, 0)));
        let ext1_1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 1, 0)));
        let ext1_2 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(0, 2, 0)));
        let ext2 = Arc::new(Extension::new(ext_2_id, Version::new(1, 0, 0)));

        reg.register(ext1.clone()).unwrap();
        reg_ref.register(ext1.clone()).unwrap();
        assert_eq!(&reg, &reg_ref);

        // normal registration fails
        assert_eq!(
            reg.register(ext1_1.clone()),
            Err(ExtensionRegistryError::AlreadyRegistered(
                ext_1_id.clone(),
                Box::new(Version::new(1, 0, 0)),
                Box::new(Version::new(1, 1, 0))
            ))
        );

        // register with update works
        reg_ref.register_updated_ref(&ext1_1);
        reg.register_updated(ext1_1.clone());
        assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
        assert_eq!(&reg, &reg_ref);

        // register with lower version does not change version
        reg_ref.register_updated_ref(&ext1_2);
        reg.register_updated(ext1_2.clone());
        assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
        assert_eq!(&reg, &reg_ref);

        // a different id can still be registered normally
        reg.register(ext2.clone()).unwrap();
        assert_eq!(reg.get("ext2").unwrap().version(), &Version::new(1, 0, 0));
        assert_eq!(reg.len(), 2);

        // removal hands back the extension that was stored
        assert!(reg.remove_extension(&ext_1_id).unwrap().version() == &Version::new(1, 1, 0));
        assert_eq!(reg.len(), 1);
    }
923
    /// Property-testing support for [`ExtensionSet`].
    mod proptest {

        use ::proptest::{collection::hash_set, prelude::*};

        use super::super::{ExtensionId, ExtensionSet};

        impl Arbitrary for ExtensionSet {
            type Parameters = ();
            type Strategy = BoxedStrategy<Self>;

            /// Generates sets built from a small number of arbitrary
            /// extension ids (size drawn from the range `0..3`).
            fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
                hash_set(any::<ExtensionId>(), 0..3)
                    .prop_map(|extensions| extensions.into_iter().collect::<ExtensionSet>())
                    .boxed()
            }
        }
    }
941}