hugr_core/
extension.rs

1//! Extensions
2//!
3//! TODO: YAML declaration and parsing. This should be similar to a plugin
4//! system (outside the `types` module), which also parses nested [`OpDef`]s.
5
6use itertools::Itertools;
7use resolution::{ExtensionResolutionError, WeakExtensionRegistry};
8pub use semver::Version;
9use serde::{Deserialize, Deserializer, Serialize};
10use std::cell::UnsafeCell;
11use std::collections::btree_map;
12use std::collections::{BTreeMap, BTreeSet};
13use std::fmt::Debug;
14use std::sync::atomic::{AtomicBool, Ordering};
15use std::sync::{Arc, Weak};
16use std::{io, mem};
17
18use derive_more::Display;
19use thiserror::Error;
20
21use crate::hugr::IdentList;
22use crate::ops::constant::{ValueName, ValueNameRef};
23use crate::ops::custom::{ExtensionOp, OpaqueOp};
24use crate::ops::{self, OpName, OpNameRef};
25use crate::types::type_param::{TypeArg, TypeArgError, TypeParam};
26use crate::types::RowVariable;
27use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName};
28use crate::types::{Signature, TypeNameRef};
29
30mod const_fold;
31mod op_def;
32pub mod prelude;
33pub mod resolution;
34pub mod simple_op;
35mod type_def;
36
37pub use const_fold::{fold_out_row, ConstFold, ConstFoldResult, Folder};
38pub use op_def::{
39    CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc,
40    ValidateJustArgs, ValidateTypeArgs,
41};
42pub use prelude::{PRELUDE, PRELUDE_REGISTRY};
43pub use type_def::{TypeDef, TypeDefBound};
44
45#[cfg(feature = "declarative")]
46pub mod declarative;
47
/// Extension Registries store extensions to be looked up e.g. during validation.
///
/// Extensions are keyed by their [`ExtensionId`], so the registry holds at most
/// one [`Extension`] per id.
#[derive(Debug, Display, Default)]
#[display("ExtensionRegistry[{}]", exts.keys().join(", "))]
pub struct ExtensionRegistry {
    /// The extensions in the registry.
    exts: BTreeMap<ExtensionId, Arc<Extension>>,
    /// A flag indicating whether the current set of extensions has been
    /// validated.
    ///
    /// This is used to avoid re-validating the extensions every time the
    /// registry is validated, and is reset to `false` whenever an extension is
    /// registered or removed.
    valid: AtomicBool,
}
62
63impl PartialEq for ExtensionRegistry {
64    fn eq(&self, other: &Self) -> bool {
65        self.exts == other.exts
66    }
67}
68
69impl Clone for ExtensionRegistry {
70    fn clone(&self) -> Self {
71        Self {
72            exts: self.exts.clone(),
73            valid: self.valid.load(Ordering::Relaxed).into(),
74        }
75    }
76}
77
78impl ExtensionRegistry {
79    /// Create a new empty extension registry.
80    pub fn new(extensions: impl IntoIterator<Item = Arc<Extension>>) -> Self {
81        let mut res = Self::default();
82        for ext in extensions.into_iter() {
83            res.register_updated(ext);
84        }
85        res
86    }
87
    /// Load an ExtensionRegistry serialized as json.
    ///
    /// After deserialization, updates all the internal `Weak<Extension>`
    /// references to point to the newly created [`Arc`]s in the registry,
    /// or extensions in the `other_extensions` parameter.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtensionRegistryLoadError`] if the input cannot be
    /// deserialized as a list of extensions, or if resolving the internal
    /// extension references fails.
    pub fn load_json(
        reader: impl io::Read,
        other_extensions: &ExtensionRegistry,
    ) -> Result<Self, ExtensionRegistryLoadError> {
        let extensions: Vec<Extension> = serde_json::from_reader(reader)?;
        // After deserialization, we need to update all the internal
        // `Weak<Extension>` references.
        Ok(ExtensionRegistry::new_with_extension_resolution(
            extensions,
            &other_extensions.into(),
        )?)
    }
105
    /// Gets the Extension with the given name, or `None` if no extension with
    /// that name is registered.
    pub fn get(&self, name: &str) -> Option<&Arc<Extension>> {
        self.exts.get(name)
    }
110
    /// Returns `true` if the registry contains an extension with the given name.
    ///
    /// Only the name is checked; the version of the stored extension is ignored.
    pub fn contains(&self, name: &str) -> bool {
        self.exts.contains_key(name)
    }
115
116    /// Validate the set of extensions.
117    pub fn validate(&self) -> Result<(), ExtensionRegistryError> {
118        if self.valid.load(Ordering::Relaxed) {
119            return Ok(());
120        }
121        for ext in self.exts.values() {
122            ext.validate()
123                .map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?;
124        }
125        self.valid.store(true, Ordering::Relaxed);
126        Ok(())
127    }
128
129    /// Registers a new extension to the registry.
130    ///
131    /// Returns a reference to the registered extension if successful.
132    pub fn register(
133        &mut self,
134        extension: impl Into<Arc<Extension>>,
135    ) -> Result<(), ExtensionRegistryError> {
136        let extension = extension.into();
137        match self.exts.entry(extension.name().clone()) {
138            btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered(
139                extension.name().clone(),
140                prev.get().version().clone(),
141                extension.version().clone(),
142            )),
143            btree_map::Entry::Vacant(ve) => {
144                ve.insert(extension);
145                // Clear the valid flag so that the registry is re-validated.
146                self.valid.store(false, Ordering::Relaxed);
147
148                Ok(())
149            }
150        }
151    }
152
153    /// Registers a new extension to the registry, keeping the one most up to
154    /// date if the extension already exists.
155    ///
156    /// If extension IDs match, the extension with the higher version is kept.
157    /// If versions match, the original extension is kept. Returns a reference
158    /// to the registered extension if successful.
159    ///
160    /// Takes an Arc to the extension. To avoid cloning Arcs unless necessary,
161    /// see [`ExtensionRegistry::register_updated_ref`].
162    pub fn register_updated(&mut self, extension: impl Into<Arc<Extension>>) {
163        let extension = extension.into();
164        match self.exts.entry(extension.name().clone()) {
165            btree_map::Entry::Occupied(mut prev) => {
166                if prev.get().version() < extension.version() {
167                    *prev.get_mut() = extension;
168                }
169            }
170            btree_map::Entry::Vacant(ve) => {
171                ve.insert(extension);
172            }
173        }
174        // Clear the valid flag so that the registry is re-validated.
175        self.valid.store(false, Ordering::Relaxed);
176    }
177
178    /// Registers a new extension to the registry, keeping the one most up to
179    /// date if the extension already exists.
180    ///
181    /// If extension IDs match, the extension with the higher version is kept.
182    /// If versions match, the original extension is kept. Returns a reference
183    /// to the registered extension if successful.
184    ///
185    /// Clones the Arc only when required. For no-cloning version see
186    /// [`ExtensionRegistry::register_updated`].
187    pub fn register_updated_ref(&mut self, extension: &Arc<Extension>) {
188        match self.exts.entry(extension.name().clone()) {
189            btree_map::Entry::Occupied(mut prev) => {
190                if prev.get().version() < extension.version() {
191                    *prev.get_mut() = extension.clone();
192                }
193            }
194            btree_map::Entry::Vacant(ve) => {
195                ve.insert(extension.clone());
196            }
197        }
198        // Clear the valid flag so that the registry is re-validated.
199        self.valid.store(false, Ordering::Relaxed);
200    }
201
    /// Returns the number of extensions in the registry.
    ///
    /// Each extension id is counted once.
    pub fn len(&self) -> usize {
        self.exts.len()
    }
206
    /// Returns `true` if the registry contains no extensions, i.e. if
    /// [`ExtensionRegistry::len`] is zero.
    pub fn is_empty(&self) -> bool {
        self.exts.is_empty()
    }
211
    /// Returns an iterator over the extensions in the registry.
    ///
    /// Extensions are yielded in the order of their ids, as stored in the
    /// underlying [`BTreeMap`].
    pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
        self.exts.values()
    }
216
    /// Returns an iterator over the ids of the extensions in the registry, in
    /// sorted order.
    pub fn ids(&self) -> impl Iterator<Item = &ExtensionId> {
        self.exts.keys()
    }
221
222    /// Delete an extension from the registry and return it if it was present.
223    pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Arc<Extension>> {
224        // Clear the valid flag so that the registry is re-validated.
225        self.valid.store(false, Ordering::Relaxed);
226
227        self.exts.remove(name)
228    }
229
    /// Constructs a new ExtensionRegistry from a list of [`Extension`]s while
    /// giving you a [`WeakExtensionRegistry`] to the allocation. This allows
    /// you to add [`Weak`] self-references to the [`Extension`]s while
    /// constructing them, before wrapping them in [`Arc`]s.
    ///
    /// This is similar to [`Arc::new_cyclic`], but for ExtensionRegistries.
    ///
    /// Calling [`Weak::upgrade`] on a weak reference in the
    /// [`WeakExtensionRegistry`] inside your closure will return an extension
    /// with no internal (op / type / value) definitions.
    ///
    /// The extensions returned by `init` are paired positionally with the
    /// allocations created for the input extensions, so `init` is expected to
    /// return the same extensions in the same order.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `init`, if any.
    //
    // It may be possible to implement this safely using `Arc::new_cyclic`
    // directly, but the callback type does not allow for returning extra
    // data so it seems unlikely.
    pub fn new_cyclic<F, E>(
        extensions: impl IntoIterator<Item = Extension>,
        init: F,
    ) -> Result<Self, E>
    where
        F: FnOnce(Vec<Extension>, &WeakExtensionRegistry) -> Result<Vec<Extension>, E>,
    {
        let extensions = extensions.into_iter().collect_vec();

        // Unsafe internally-mutable wrapper around an extension. Important:
        // `repr(transparent)` ensures the layout is identical to `Extension`,
        // so it can be safely transmuted.
        #[repr(transparent)]
        struct ExtensionCell {
            ext: UnsafeCell<Extension>,
        }

        // Create the arcs with internal mutability, and collect weak references
        // over immutable references.
        //
        // This is safe as long as the cell mutation happens when we can guarantee
        // that the weak references are not used.
        let (arcs, weaks): (Vec<Arc<ExtensionCell>>, Vec<Weak<Extension>>) = extensions
            .iter()
            .map(|ext| {
                // Create a new arc with an empty extension sharing the name and version of the original,
                // but with no internal definitions.
                //
                // `UnsafeCell` is not sync, but we are not writing to it while the weak references are
                // being used.
                #[allow(clippy::arc_with_non_send_sync)]
                let arc = Arc::new(ExtensionCell {
                    ext: UnsafeCell::new(Extension::new(ext.name().clone(), ext.version().clone())),
                });

                // SAFETY: `ExtensionCell` is `repr(transparent)`, so it has the same layout as `Extension`.
                let weak_arc: Weak<Extension> = unsafe { mem::transmute(Arc::downgrade(&arc)) };
                (arc, weak_arc)
            })
            .unzip();

        // Register each weak reference under the id of the extension it stands for.
        let mut weak_registry = WeakExtensionRegistry::default();
        for (ext, weak) in extensions.iter().zip(weaks) {
            weak_registry.register(ext.name().clone(), weak);
        }

        // Actual initialization here
        // Upgrading the weak references at any point here will access the empty extensions in the arcs.
        let extensions = init(extensions, &weak_registry)?;

        // We're done.
        // NOTE(review): `zip` stops at the shorter side, so if `init` returns fewer
        // extensions than it was given, the remaining placeholder arcs are dropped
        // silently; a reordered result would be written into the wrong allocation.
        // NOTE(review): the write below assumes `init` did not keep an upgraded
        // `Arc` alive past its return — confirm callers uphold this.
        let arcs: Vec<Arc<Extension>> = arcs
            .into_iter()
            .zip(extensions)
            .map(|(arc, ext)| {
                // Replace the dummy extensions with the updated ones.
                // SAFETY: The cell is only mutated when the weak references are not used.
                unsafe { *arc.ext.get() = ext };
                // Pretend the UnsafeCells never existed.
                // SAFETY: `ExtensionCell` is `repr(transparent)`, so it has the same layout as `Extension`.
                unsafe { mem::transmute::<Arc<ExtensionCell>, Arc<Extension>>(arc) }
            })
            .collect();
        Ok(ExtensionRegistry::new(arcs))
    }
309}
310
/// Consuming iteration over the registered extensions, in id order. The ids
/// themselves are discarded.
impl IntoIterator for ExtensionRegistry {
    type Item = Arc<Extension>;

    type IntoIter = std::collections::btree_map::IntoValues<ExtensionId, Arc<Extension>>;

    fn into_iter(self) -> Self::IntoIter {
        self.exts.into_values()
    }
}
320
/// Borrowing iteration over the registered extensions, in id order.
///
/// This is the iterator returned by [`ExtensionRegistry::iter`].
impl<'a> IntoIterator for &'a ExtensionRegistry {
    type Item = &'a Arc<Extension>;

    type IntoIter = std::collections::btree_map::Values<'a, ExtensionId, Arc<Extension>>;

    fn into_iter(self) -> Self::IntoIter {
        self.exts.values()
    }
}
330
331impl<'a> Extend<&'a Arc<Extension>> for ExtensionRegistry {
332    fn extend<T: IntoIterator<Item = &'a Arc<Extension>>>(&mut self, iter: T) {
333        for ext in iter {
334            self.register_updated_ref(ext);
335        }
336    }
337}
338
339impl Extend<Arc<Extension>> for ExtensionRegistry {
340    fn extend<T: IntoIterator<Item = Arc<Extension>>>(&mut self, iter: T) {
341        for ext in iter {
342            self.register_updated(ext);
343        }
344    }
345}
346
347/// Encode/decode ExtensionRegistry as a list of extensions.
348///
349/// Any `Weak<Extension>` references inside the registry will be left unresolved.
350/// Prefer using [`ExtensionRegistry::load_json`] when deserializing.
351impl<'de> Deserialize<'de> for ExtensionRegistry {
352    fn deserialize<D>(deserializer: D) -> Result<ExtensionRegistry, D::Error>
353    where
354        D: Deserializer<'de>,
355    {
356        let extensions: Vec<Arc<Extension>> = Vec::deserialize(deserializer)?;
357        Ok(ExtensionRegistry::new(extensions))
358    }
359}
360
361impl Serialize for ExtensionRegistry {
362    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
363    where
364        S: serde::Serializer,
365    {
366        let extensions: Vec<Arc<Extension>> = self.exts.values().cloned().collect();
367        extensions.serialize(serializer)
368    }
369}
370
/// An Extension Registry containing no extensions.
///
/// It is marked as already validated, since there is nothing to validate.
pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry {
    exts: BTreeMap::new(),
    valid: AtomicBool::new(true),
};
376
/// An error that can occur in computing the signature of a node.
/// TODO: decide on failure modes
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[allow(missing_docs)]
pub enum SignatureError {
    /// Name mismatch
    #[error("Definition name ({0}) and instantiation name ({1}) do not match.")]
    NameMismatch(TypeName, TypeName),
    /// Extension mismatch
    #[error("Definition extension ({0}) and instantiation extension ({1}) do not match.")]
    ExtensionMismatch(ExtensionId, ExtensionId),
    /// When the type arguments of the node did not match the params declared by the OpDef
    #[error("Type arguments of node did not match params declared by definition: {0}")]
    TypeArgMismatch(#[from] TypeArgError),
    /// Invalid type arguments
    #[error("Invalid type arguments for operation")]
    InvalidTypeArgs,
    /// The weak [`Extension`] reference for a custom type has been dropped.
    #[error("Type '{typ}' is defined in extension '{missing}', but the extension reference has been dropped.")]
    MissingTypeExtension { typ: TypeName, missing: ExtensionId },
    /// The Extension was found in the registry, but did not contain the Type(Def) referenced in the Signature
    #[error("Extension '{exn}' did not contain expected TypeDef '{typ}'")]
    ExtensionTypeNotFound { exn: ExtensionId, typ: TypeName },
    /// The bound recorded for a CustomType doesn't match what the TypeDef would compute
    #[error("Bound on CustomType ({actual}) did not match TypeDef ({expected})")]
    WrongBound {
        actual: TypeBound,
        expected: TypeBound,
    },
    /// A Type Variable's cache of its declared kind is incorrect
    #[error("Type Variable claims to be {cached} but actual declaration {actual}")]
    TypeVarDoesNotMatchDeclaration {
        actual: TypeParam,
        cached: TypeParam,
    },
    /// A type variable that was used has not been declared
    #[error("Type variable {idx} was not declared ({num_decls} in scope)")]
    FreeTypeVar { idx: usize, num_decls: usize },
    /// A row variable was found outside of a variable-length row
    #[error("Expected a single type, but found row variable {var}")]
    RowVarWhereTypeExpected { var: RowVariable },
    /// The result of the type application stored in a [Call]
    /// is not what we get by applying the type-args to the polymorphic function
    ///
    /// [Call]: crate::ops::dataflow::Call
    #[error(
        "Incorrect result of type application in Call - cached {cached} but expected {expected}"
    )]
    CallIncorrectlyAppliesType {
        cached: Signature,
        expected: Signature,
    },
    /// The result of the type application stored in a [LoadFunction]
    /// is not what we get by applying the type-args to the polymorphic function
    ///
    /// [LoadFunction]: crate::ops::dataflow::LoadFunction
    #[error(
        "Incorrect result of type application in LoadFunction - cached {cached} but expected {expected}"
    )]
    LoadFunctionIncorrectlyAppliesType {
        cached: Signature,
        expected: Signature,
    },

    /// Extension declaration specifies a binary compute signature function, but none
    /// was loaded.
    #[error("Binary compute signature function not loaded.")]
    MissingComputeFunc,

    /// Extension declaration specifies a binary validate signature function, but none
    /// was loaded.
    #[error("Binary validate signature function not loaded.")]
    MissingValidateFunc,
}
451
/// Concrete instantiations of types and operations defined in extensions.
///
/// Gives uniform access to the definition name, type arguments and owning
/// extension of an instantiated extension item.
trait CustomConcrete {
    /// The identifier type for the concrete object.
    type Identifier;
    /// A generic identifier to the element.
    ///
    /// This may either refer to a [`TypeName`] or an [`OpName`].
    fn def_name(&self) -> &Self::Identifier;
    /// The concrete type arguments for the instantiation.
    fn type_args(&self) -> &[TypeArg];
    /// Extension required by the instantiation.
    fn parent_extension(&self) -> &ExtensionId;
}
465
/// An [`OpaqueOp`] is identified by its [`OpName`]; every method forwards to
/// the corresponding accessor on the op.
impl CustomConcrete for OpaqueOp {
    type Identifier = OpName;

    fn def_name(&self) -> &OpName {
        self.op_name()
    }

    fn type_args(&self) -> &[TypeArg] {
        self.args()
    }

    fn parent_extension(&self) -> &ExtensionId {
        self.extension()
    }
}
481
/// A [`CustomType`] is identified by its [`TypeName`]; every method forwards
/// to the corresponding accessor on the type.
impl CustomConcrete for CustomType {
    type Identifier = TypeName;

    fn def_name(&self) -> &TypeName {
        // The type's own name serves as the definition identifier.
        self.name()
    }

    fn type_args(&self) -> &[TypeArg] {
        self.args()
    }

    fn parent_extension(&self) -> &ExtensionId {
        self.extension()
    }
}
498
/// A constant value provided by an extension.
/// Must be an instance of a type available to the extension.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ExtensionValue {
    // Id of the extension that defines this value.
    extension: ExtensionId,
    // Name under which the value is stored in its extension.
    name: ValueName,
    // The value itself.
    typed_value: ops::Value,
}
507
/// Read accessors for the parts of an [`ExtensionValue`], plus a restricted
/// mutable accessor for the value itself.
impl ExtensionValue {
    /// Returns a reference to the typed value of this [`ExtensionValue`].
    pub fn typed_value(&self) -> &ops::Value {
        &self.typed_value
    }

    /// Returns a mutable reference to the typed value of this [`ExtensionValue`].
    ///
    /// Only visible to the parent module.
    pub(super) fn typed_value_mut(&mut self) -> &mut ops::Value {
        &mut self.typed_value
    }

    /// Returns the name of this [`ExtensionValue`] as a string slice.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns a reference to the id of the extension this [`ExtensionValue`]
    /// belongs to.
    pub fn extension(&self) -> &ExtensionId {
        &self.extension
    }
}
529
/// A unique identifier for an extension.
///
/// The actual [`Extension`] is stored externally.
pub type ExtensionId = IdentList;
534
/// An extension is a set of capabilities required to execute a graph.
///
/// These are normally defined once and shared across multiple graphs and
/// operations wrapped in [`Arc`]s inside [`ExtensionRegistry`].
///
/// # Example
///
/// The following example demonstrates how to define a new extension with a
/// custom operation and a custom type.
///
/// When using [`Arc`]s, the extension can only be modified at creation time. The
/// defined operations and types keep a [`Weak`] reference to their extension. We provide a
/// helper method [`Extension::new_arc`] to aid their definition.
///
/// ```
/// # use hugr_core::types::Signature;
/// # use hugr_core::extension::{Extension, ExtensionId, Version};
/// # use hugr_core::extension::{TypeDefBound};
/// Extension::new_arc(
///     ExtensionId::new_unchecked("my.extension"),
///     Version::new(0, 1, 0),
///     |ext, extension_ref| {
///         // Add a custom type definition
///         ext.add_type(
///             "MyType".into(),
///             vec![], // No type parameters
///             "Some type".into(),
///             TypeDefBound::any(),
///             extension_ref,
///         );
///         // Add a custom operation
///         ext.add_op(
///             "MyOp".into(),
///             "Some operation".into(),
///             Signature::new_endo(vec![]),
///             extension_ref,
///         );
///     },
/// );
/// ```
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Extension {
    /// Extension version, follows semver.
    pub version: Version,
    /// Unique identifier for the extension.
    pub name: ExtensionId,
    /// Runtime dependencies this extension has on other extensions.
    pub runtime_reqs: ExtensionSet,
    /// Types defined by this extension.
    types: BTreeMap<TypeName, TypeDef>,
    /// Static values defined by this extension.
    values: BTreeMap<ValueName, ExtensionValue>,
    /// Operation declarations with serializable definitions.
    // Note: serde will serialize this because we configure with `features=["rc"]`.
    // That will clone anything that has multiple references, but each
    // OpDef should appear exactly once in this map (keyed by its name),
    // and the other references to the OpDef are from ExternalOp's in the Hugr
    // (which are serialized as OpaqueOp's i.e. Strings).
    operations: BTreeMap<OpName, Arc<op_def::OpDef>>,
}
595
596impl Extension {
597    /// Creates a new extension with the given name.
598    ///
599    /// In most cases extensions are contained inside an [`Arc`] so that they
600    /// can be shared across hugr instances and operation definitions.
601    ///
602    /// See [`Extension::new_arc`] for a more ergonomic way to create boxed
603    /// extensions.
604    pub fn new(name: ExtensionId, version: Version) -> Self {
605        Self {
606            name,
607            version,
608            runtime_reqs: Default::default(),
609            types: Default::default(),
610            values: Default::default(),
611            operations: Default::default(),
612        }
613    }
614
615    /// Creates a new extension wrapped in an [`Arc`].
616    ///
617    /// The closure lets us use a weak reference to the arc while the extension
618    /// is being built. This is necessary for calling [`Extension::add_op`] and
619    /// [`Extension::add_type`].
620    pub fn new_arc(
621        name: ExtensionId,
622        version: Version,
623        init: impl FnOnce(&mut Extension, &Weak<Extension>),
624    ) -> Arc<Self> {
625        Arc::new_cyclic(|extension_ref| {
626            let mut ext = Self::new(name, version);
627            init(&mut ext, extension_ref);
628            ext
629        })
630    }
631
632    /// Creates a new extension wrapped in an [`Arc`], using a fallible
633    /// initialization function.
634    ///
635    /// The closure lets us use a weak reference to the arc while the extension
636    /// is being built. This is necessary for calling [`Extension::add_op`] and
637    /// [`Extension::add_type`].
638    pub fn try_new_arc<E>(
639        name: ExtensionId,
640        version: Version,
641        init: impl FnOnce(&mut Extension, &Weak<Extension>) -> Result<(), E>,
642    ) -> Result<Arc<Self>, E> {
643        // Annoying hack around not having `Arc::try_new_cyclic` that can return
644        // a Result.
645        // https://github.com/rust-lang/rust/issues/75861#issuecomment-980455381
646        //
647        // When there is an error, we store it in `error` and return it at the
648        // end instead of the partially-initialized extension.
649        let mut error = None;
650        let ext = Arc::new_cyclic(|extension_ref| {
651            let mut ext = Self::new(name, version);
652            match init(&mut ext, extension_ref) {
653                Ok(_) => ext,
654                Err(e) => {
655                    error = Some(e);
656                    ext
657                }
658            }
659        });
660        match error {
661            Some(e) => Err(e),
662            None => Ok(ext),
663        }
664    }
665
666    /// Extend the runtime requirements of this extension with another set of extensions.
667    pub fn add_requirements(&mut self, runtime_reqs: impl Into<ExtensionSet>) {
668        let reqs = mem::take(&mut self.runtime_reqs);
669        self.runtime_reqs = reqs.union(runtime_reqs.into());
670    }
671
    /// Allows read-only access to the operations in this Extension.
    ///
    /// Returns `None` if no operation with the given name is defined.
    pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc<op_def::OpDef>> {
        self.operations.get(name)
    }
676
    /// Allows read-only access to the types in this Extension.
    ///
    /// Returns `None` if no type with the given name is defined.
    pub fn get_type(&self, type_name: &TypeNameRef) -> Option<&type_def::TypeDef> {
        self.types.get(type_name)
    }
681
    /// Allows read-only access to the values in this Extension.
    ///
    /// Returns `None` if no value with the given name is defined.
    pub fn get_value(&self, value_name: &ValueNameRef) -> Option<&ExtensionValue> {
        self.values.get(value_name)
    }
686
    /// Returns the name of the extension, which is also its unique
    /// [`ExtensionId`].
    pub fn name(&self) -> &ExtensionId {
        &self.name
    }
691
    /// Returns the (semver) version of the extension.
    pub fn version(&self) -> &Version {
        &self.version
    }
696
    /// Iterator over the operations of this [`Extension`], as
    /// `(name, definition)` pairs sorted by name.
    pub fn operations(&self) -> impl Iterator<Item = (&OpName, &Arc<OpDef>)> {
        self.operations.iter()
    }
701
    /// Iterator over the types of this [`Extension`], as `(name, definition)`
    /// pairs sorted by name.
    pub fn types(&self) -> impl Iterator<Item = (&TypeName, &TypeDef)> {
        self.types.iter()
    }
706
707    /// Add a named static value to the extension.
708    pub fn add_value(
709        &mut self,
710        name: impl Into<ValueName>,
711        typed_value: ops::Value,
712    ) -> Result<&mut ExtensionValue, ExtensionBuildError> {
713        let extension_value = ExtensionValue {
714            extension: self.name.clone(),
715            name: name.into(),
716            typed_value,
717        };
718        match self.values.entry(extension_value.name.clone()) {
719            btree_map::Entry::Occupied(_) => {
720                Err(ExtensionBuildError::ValueExists(extension_value.name))
721            }
722            btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension_value)),
723        }
724    }
725
    /// Instantiate an [`ExtensionOp`] which references an [`OpDef`] in this extension.
    ///
    /// # Panics
    ///
    /// Panics if this extension does not define an operation called `name`.
    ///
    /// # Errors
    ///
    /// Propagates any [`SignatureError`] returned by [`ExtensionOp::new`] for
    /// the given type arguments.
    pub fn instantiate_extension_op(
        &self,
        name: &OpNameRef,
        args: impl Into<Vec<TypeArg>>,
    ) -> Result<ExtensionOp, SignatureError> {
        // A missing op is treated as a caller bug rather than a recoverable error.
        let op_def = self.get_op(name).expect("Op not found.");
        ExtensionOp::new(op_def.clone(), args)
    }
735
    /// Validates the operation definitions in this extension.
    ///
    /// Stops at, and returns, the first error reported by an operation
    /// definition. Type definitions are not checked yet.
    fn validate(&self) -> Result<(), SignatureError> {
        // We should validate TypeParams of TypeDefs too - https://github.com/CQCL/hugr/issues/624
        for op_def in self.operations.values() {
            op_def.validate()?;
        }
        Ok(())
    }
744}
745
746impl PartialEq for Extension {
747    fn eq(&self, other: &Self) -> bool {
748        self.name == other.name && self.version == other.version
749    }
750}
751
/// An error that can occur in defining an extension registry.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExtensionRegistryError {
    /// Extension already defined.
    ///
    /// Carries the extension id, the version already in the registry, and the
    /// version of the extension that was being registered.
    #[error("The registry already contains an extension with id {0} and version {1}. New extension has version {2}.")]
    AlreadyRegistered(ExtensionId, Version, Version),
    /// A registered extension has invalid signatures.
    ///
    /// Carries the id of the offending extension and the underlying error.
    #[error("The extension {0} contains an invalid signature, {1}.")]
    InvalidSignature(ExtensionId, #[source] SignatureError),
}
763
/// An error that can occur while loading an extension registry.
///
/// Returned by [`ExtensionRegistry::load_json`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ExtensionRegistryLoadError {
    /// Deserialization error.
    #[error(transparent)]
    SerdeError(#[from] serde_json::Error),
    /// Error when resolving internal extension references.
    #[error(transparent)]
    ExtensionResolutionError(#[from] ExtensionResolutionError),
}
775
/// An error that can occur in building a new extension.
///
/// Each variant carries the name that was already taken.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExtensionBuildError {
    /// Existing [`OpDef`]
    #[error("Extension already has an op called {0}.")]
    OpDefExists(OpName),
    /// Existing [`TypeDef`]
    #[error("Extension already has an type called {0}.")]
    TypeDefExists(TypeName),
    /// Existing [`ExtensionValue`]
    #[error("Extension already has an extension value called {0}.")]
    ValueExists(ValueName),
}
790
/// A set of extensions identified by their unique [`ExtensionId`].
///
/// Displayed as a comma-separated list of ids in square brackets.
#[derive(
    Clone, Debug, Display, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize,
)]
#[display("[{}]", _0.iter().join(", "))]
pub struct ExtensionSet(BTreeSet<ExtensionId>);
797
/// A special ExtensionId which indicates that the delta of a non-Function
/// container node should be computed by extension inference.
///
/// See [`infer_extensions`] which lists the container nodes to which this can be applied.
///
/// [`infer_extensions`]: crate::hugr::Hugr::infer_extensions
// Built with `new_unchecked`, i.e. without validating the identifier.
pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked(".TO_BE_INFERRED");
805
806impl ExtensionSet {
    /// Creates a new empty extension set.
    ///
    /// Usable in `const` contexts.
    pub const fn new() -> Self {
        Self(BTreeSet::new())
    }
811
812    /// Adds a extension to the set.
813    pub fn insert(&mut self, extension: ExtensionId) {
814        self.0.insert(extension.clone());
815    }
816
817    /// Adds a type var (which must have been declared as a [TypeParam::Extensions]) to this set
818    pub fn insert_type_var(&mut self, idx: usize) {
819        // Represent type vars as string representation of variable index.
820        // This is not a legal IdentList or ExtensionId so should not conflict.
821        self.0
822            .insert(ExtensionId::new_unchecked(idx.to_string().as_str()));
823    }
824
825    /// Returns `true` if the set contains the given extension.
826    pub fn contains(&self, extension: &ExtensionId) -> bool {
827        self.0.contains(extension)
828    }
829
830    /// Returns `true` if the set is a subset of `other`.
831    pub fn is_subset(&self, other: &Self) -> bool {
832        self.0.is_subset(&other.0)
833    }
834
835    /// Returns `true` if the set is a superset of `other`.
836    pub fn is_superset(&self, other: &Self) -> bool {
837        self.0.is_superset(&other.0)
838    }
839
840    /// Create a extension set with a single element.
841    pub fn singleton(extension: ExtensionId) -> Self {
842        let mut set = Self::new();
843        set.insert(extension);
844        set
845    }
846
847    /// An ExtensionSet containing a single type variable
848    /// (which must have been declared as a [TypeParam::Extensions])
849    pub fn type_var(idx: usize) -> Self {
850        let mut set = Self::new();
851        set.insert_type_var(idx);
852        set
853    }
854
855    /// Returns the union of two extension sets.
856    pub fn union(mut self, other: Self) -> Self {
857        self.0.extend(other.0);
858        self
859    }
860
861    /// Returns the union of an arbitrary collection of [ExtensionSet]s
862    pub fn union_over(sets: impl IntoIterator<Item = Self>) -> Self {
863        // `union` clones the receiver, which we do not need to do here
864        let mut res = ExtensionSet::new();
865        for s in sets {
866            res.0.extend(s.0)
867        }
868        res
869    }
870
871    /// The things in other which are in not in self
872    pub fn missing_from(&self, other: &Self) -> Self {
873        ExtensionSet::from_iter(other.0.difference(&self.0).cloned())
874    }
875
876    /// Iterate over the contained ExtensionIds
877    pub fn iter(&self) -> impl Iterator<Item = &ExtensionId> {
878        self.0.iter()
879    }
880
881    /// True if this set contains no [ExtensionId]s
882    pub fn is_empty(&self) -> bool {
883        self.0.is_empty()
884    }
885
886    pub(crate) fn validate(&self, params: &[TypeParam]) -> Result<(), SignatureError> {
887        self.iter()
888            .filter_map(as_typevar)
889            .try_for_each(|var_idx| check_typevar_decl(params, var_idx, &TypeParam::Extensions))
890    }
891
892    pub(crate) fn substitute(&self, t: &Substitution) -> Self {
893        Self::from_iter(self.0.iter().flat_map(|e| match as_typevar(e) {
894            None => vec![e.clone()],
895            Some(i) => match t.apply_var(i, &TypeParam::Extensions) {
896                TypeArg::Extensions{es} => es.iter().cloned().collect::<Vec<_>>(),
897                _ => panic!("value for type var was not extension set - type scheme should be validated first"),
898            },
899        }))
900    }
901}
902
903impl From<ExtensionId> for ExtensionSet {
904    fn from(id: ExtensionId) -> Self {
905        Self::singleton(id)
906    }
907}
908
909impl IntoIterator for ExtensionSet {
910    type Item = ExtensionId;
911    type IntoIter = std::collections::btree_set::IntoIter<ExtensionId>;
912
913    fn into_iter(self) -> Self::IntoIter {
914        self.0.into_iter()
915    }
916}
917
918impl<'a> IntoIterator for &'a ExtensionSet {
919    type Item = &'a ExtensionId;
920    type IntoIter = std::collections::btree_set::Iter<'a, ExtensionId>;
921
922    fn into_iter(self) -> Self::IntoIter {
923        self.0.iter()
924    }
925}
926
927fn as_typevar(e: &ExtensionId) -> Option<usize> {
928    // Type variables are represented as radix-10 numbers, which are illegal
929    // as standard ExtensionIds. Hence if an ExtensionId starts with a digit,
930    // we assume it must be a type variable, and fail fast if it isn't.
931    match e.chars().next() {
932        Some(c) if c.is_ascii_digit() => Some(str::parse(e).unwrap()),
933        _ => None,
934    }
935}
936
937impl FromIterator<ExtensionId> for ExtensionSet {
938    fn from_iter<I: IntoIterator<Item = ExtensionId>>(iter: I) -> Self {
939        Self(BTreeSet::from_iter(iter))
940    }
941}
942
943/// Extension tests.
944#[cfg(test)]
945pub mod test {
946    // We re-export this here because mod op_def is private.
947    pub use super::op_def::test::SimpleOpDef;
948
949    use super::*;
950
    impl Extension {
        /// Create a new extension for testing, with a 0 version.
        ///
        /// Thin wrapper around `Extension::new_arc` that fixes the version to
        /// `0.0.0` and forwards `name` and `init` unchanged.
        pub(crate) fn new_test_arc(
            name: ExtensionId,
            init: impl FnOnce(&mut Extension, &Weak<Extension>),
        ) -> Arc<Self> {
            Self::new_arc(name, Version::new(0, 0, 0), init)
        }

        /// Create a new extension for testing, with a 0 version, using a
        /// fallible initializer.
        ///
        /// Thin wrapper around `Extension::try_new_arc` that fixes the version
        /// to `0.0.0` and forwards `name` and `init` unchanged.
        // NOTE(review): an `Err` returned by `init` is presumably passed
        // through by `try_new_arc` — confirm against its definition.
        pub(crate) fn try_new_test_arc(
            name: ExtensionId,
            init: impl FnOnce(
                &mut Extension,
                &Weak<Extension>,
            ) -> Result<(), Box<dyn std::error::Error>>,
        ) -> Result<Arc<Self>, Box<dyn std::error::Error>> {
            Self::try_new_arc(name, Version::new(0, 0, 0), init)
        }
    }
971
972    #[test]
973    fn test_register_update() {
974        // Two registers that should remain the same.
975        // We use them to test both `register_updated` and `register_updated_ref`.
976        let mut reg = ExtensionRegistry::default();
977        let mut reg_ref = ExtensionRegistry::default();
978
979        let ext_1_id = ExtensionId::new("ext1").unwrap();
980        let ext_2_id = ExtensionId::new("ext2").unwrap();
981        let ext1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 0, 0)));
982        let ext1_1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 1, 0)));
983        let ext1_2 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(0, 2, 0)));
984        let ext2 = Arc::new(Extension::new(ext_2_id, Version::new(1, 0, 0)));
985
986        reg.register(ext1.clone()).unwrap();
987        reg_ref.register(ext1.clone()).unwrap();
988        assert_eq!(&reg, &reg_ref);
989
990        // normal registration fails
991        assert_eq!(
992            reg.register(ext1_1.clone()),
993            Err(ExtensionRegistryError::AlreadyRegistered(
994                ext_1_id.clone(),
995                Version::new(1, 0, 0),
996                Version::new(1, 1, 0)
997            ))
998        );
999
1000        // register with update works
1001        reg_ref.register_updated_ref(&ext1_1);
1002        reg.register_updated(ext1_1.clone());
1003        assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
1004        assert_eq!(&reg, &reg_ref);
1005
1006        // register with lower version does not change version
1007        reg_ref.register_updated_ref(&ext1_2);
1008        reg.register_updated(ext1_2.clone());
1009        assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
1010        assert_eq!(&reg, &reg_ref);
1011
1012        reg.register(ext2.clone()).unwrap();
1013        assert_eq!(reg.get("ext2").unwrap().version(), &Version::new(1, 0, 0));
1014        assert_eq!(reg.len(), 2);
1015
1016        assert!(reg.remove_extension(&ext_1_id).unwrap().version() == &Version::new(1, 1, 0));
1017        assert_eq!(reg.len(), 1);
1018    }
1019
1020    mod proptest {
1021
1022        use ::proptest::{collection::hash_set, prelude::*};
1023
1024        use super::super::{ExtensionId, ExtensionSet};
1025
1026        impl Arbitrary for ExtensionSet {
1027            type Parameters = ();
1028            type Strategy = BoxedStrategy<Self>;
1029
1030            fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
1031                (
1032                    hash_set(0..10usize, 0..3),
1033                    hash_set(any::<ExtensionId>(), 0..3),
1034                )
1035                    .prop_map(|(vars, extensions)| {
1036                        ExtensionSet::union_over(
1037                            std::iter::once(extensions.into_iter().collect::<ExtensionSet>())
1038                                .chain(vars.into_iter().map(ExtensionSet::type_var)),
1039                        )
1040                    })
1041                    .boxed()
1042            }
1043        }
1044    }
1045}