hugr_core/envelope/
description.rs

1//! Description of the contents of a HUGR envelope used for debugging and error reporting.
2use crate::{
3    HugrView, Node,
4    envelope::{EnvelopeHeader, USED_EXTENSIONS_KEY},
5    ops::{DataflowOpTrait, OpType},
6};
7use itertools::Itertools;
8use semver::Version;
9
10type OptionVec<T> = Vec<Option<T>>;
11fn set_option_vec_len<T: Clone>(vec: &mut OptionVec<T>, n: usize) {
12    vec.resize(n, None);
13}
14fn set_option_vec_index<T: Clone>(vec: &mut OptionVec<T>, index: usize, value: T) {
15    if index >= vec.len() {
16        set_option_vec_len(vec, index + 1);
17    }
18    vec[index] = Some(value);
19}
20
/// Appends `items` to the vector inside `vec`, creating that vector first when
/// `vec` is `None`. After the call `vec` is always `Some`, even for empty `items`.
fn extend_option_vec<T: Clone>(vec: &mut Option<Vec<T>>, items: impl IntoIterator<Item = T>) {
    vec.get_or_insert_with(Vec::new).extend(items);
}
28
29/// High-level description of a HUGR package.
30#[derive(Debug, Clone, PartialEq, Default, serde::Serialize, schemars::JsonSchema)]
31pub struct PackageDesc {
32    /// Envelope header information.
33    #[serde(serialize_with = "header_serialize")]
34    #[schemars(with = "String")]
35    pub header: EnvelopeHeader,
36    /// Description of the modules in the package.
37    pub modules: OptionVec<ModuleDesc>,
38    /// Description of the extensions in the package.
39    #[serde(skip_serializing_if = "Vec::is_empty")]
40    #[serde(default)]
41    pub packaged_extensions: OptionVec<ExtensionDesc>,
42}
43
44fn header_serialize<S>(header: &EnvelopeHeader, serializer: S) -> Result<S::Ok, S::Error>
45where
46    S: serde::Serializer,
47{
48    serializer.serialize_str(&header.to_string())
49}
50
51impl PackageDesc {
52    /// Creates a new `PackageDesc` with the given header.
53    pub(super) fn new(header: EnvelopeHeader) -> Self {
54        Self {
55            header,
56            ..Default::default()
57        }
58    }
59
60    /// Sets the number of modules in the package.
61    pub(crate) fn set_n_modules(&mut self, n: usize) {
62        set_option_vec_len(&mut self.modules, n);
63    }
64
65    /// Returns the package header.
66    pub fn header(&self) -> EnvelopeHeader {
67        self.header
68    }
69
70    /// Returns the number of modules in the package.
71    pub fn n_modules(&self) -> usize {
72        self.modules.len()
73    }
74
75    /// Sets a module description at the specified index.
76    pub(crate) fn set_module(&mut self, index: usize, module: impl Into<ModuleDesc>) {
77        set_option_vec_index(&mut self.modules, index, module.into());
78    }
79
80    /// Sets a packaged extension description at the specified index.
81    pub(crate) fn set_packaged_extension(&mut self, index: usize, ext: impl Into<ExtensionDesc>) {
82        set_option_vec_index(&mut self.packaged_extensions, index, ext.into());
83    }
84
85    /// Returns the number of packaged extensions in the package.
86    pub fn n_packaged_extensions(&self) -> usize {
87        self.packaged_extensions.len()
88    }
89
90    /// Returns the generator(s) of the package modules, if any.
91    /// Concatenates multiple generators with commas.
92    pub fn generator(&self) -> Option<String> {
93        let generators: Vec<String> = self
94            .modules
95            .iter()
96            .flatten()
97            .flat_map(|m| &m.generator)
98            .unique()
99            .cloned()
100            .collect();
101        if generators.is_empty() {
102            return None;
103        }
104
105        Some(generators.join(", "))
106    }
107
108    /// Returns an iterator over the module descriptions. Modules with
109    /// expected but missing descriptions yield `None`.
110    pub fn modules(&self) -> impl Iterator<Item = &Option<ModuleDesc>> {
111        self.modules.iter()
112    }
113
114    /// Returns an iterator over the packaged extension descriptions. Missing extensions are skipped.
115    pub fn packaged_extensions(&self) -> impl Iterator<Item = &ExtensionDesc> {
116        self.packaged_extensions.iter().flatten()
117    }
118}
119
120/// High level description of an extension.
121#[derive(
122    derive_more::Display,
123    Debug,
124    Clone,
125    PartialEq,
126    serde::Deserialize,
127    serde::Serialize,
128    schemars::JsonSchema,
129)]
130#[display("Extension {name} v{version}")]
131pub struct ExtensionDesc {
132    /// Name of the extension.
133    pub name: String,
134    /// Version of the extension.
135    #[schemars(with = "String")]
136    pub version: Version,
137}
138
139impl ExtensionDesc {
140    /// Create a new extension description.
141    pub fn new(name: impl ToString, version: impl Into<Version>) -> Self {
142        Self {
143            name: name.to_string(),
144            version: version.into(),
145        }
146    }
147}
148
149impl<E: AsRef<crate::Extension>> From<&E> for ExtensionDesc {
150    fn from(ext: &E) -> Self {
151        let ext = ext.as_ref();
152        Self {
153            name: ext.name.to_string(),
154            version: ext.version.clone(),
155        }
156    }
157}
158
159#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
160/// Description of the entrypoint of a module.
161pub struct Entrypoint {
162    /// Node id of the entrypoint.
163    #[schemars(with = "u32")]
164    pub node: Node,
165    #[schemars(with = "String")]
166    #[serde(serialize_with = "op_serialize")]
167    /// Operation type of the entrypoint node.
168    pub optype: OpType,
169}
170
171impl Entrypoint {
172    /// Create a new entrypoint description.
173    pub fn new(node: Node, optype: OpType) -> Self {
174        Self { node, optype }
175    }
176}
177
178/// Get a string representation of an OpType for description purposes.
179pub fn op_string(op: &OpType) -> String {
180    match op {
181        OpType::FuncDefn(defn) => format!(
182            "FuncDefn({})",
183            func_symbol(defn.func_name(), defn.signature())
184        ),
185        OpType::FuncDecl(decl) => format!(
186            "FuncDecl({})",
187            func_symbol(decl.func_name(), decl.signature())
188        ),
189        OpType::DFG(dfg) => format!("DFG({})", dfg.signature()),
190        _ => format!("{op}"),
191    }
192}
193fn op_serialize<S>(op_type: &OpType, serializer: S) -> Result<S::Ok, S::Error>
194where
195    S: serde::Serializer,
196{
197    serializer.serialize_str(op_string(op_type).as_str())
198}
199
200#[derive(
201    Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize, schemars::JsonSchema,
202)]
203/// High-level description of a module in a HUGR package.
204pub struct ModuleDesc {
205    /// Number of nodes in the module.
206    #[serde(skip_serializing_if = "Option::is_none")]
207    #[serde(default)]
208    pub num_nodes: Option<usize>,
209    /// The entrypoint node and the corresponding operation type.
210    #[serde(skip_serializing_if = "Option::is_none")]
211    #[serde(default)]
212    pub entrypoint: Option<Entrypoint>,
213    /// Extensions used in the module computed while resolving, expected to be a subset of `used_extensions_generator`.
214    #[serde(skip_serializing_if = "Option::is_none")]
215    #[serde(default)]
216    pub used_extensions_resolved: Option<Vec<ExtensionDesc>>,
217    /// Generator specified in the module metadata.
218    #[serde(skip_serializing_if = "Option::is_none")]
219    #[serde(default)]
220    pub generator: Option<String>,
221    /// Generator specified used extensions in the module metadata.
222    #[serde(skip_serializing_if = "Option::is_none")]
223    #[serde(default)]
224    pub used_extensions_generator: Option<Vec<ExtensionDesc>>,
225    /// Public symbols defined in the module.
226    #[serde(skip_serializing_if = "Option::is_none")]
227    #[serde(default)]
228    pub public_symbols: Option<Vec<String>>,
229}
230
231impl ModuleDesc {
232    /// Sets the number of nodes in the module.
233    pub fn set_num_nodes(&mut self, num_nodes: usize) {
234        self.num_nodes = Some(num_nodes);
235    }
236
237    /// Sets the entrypoint of the module.
238    pub fn set_entrypoint(&mut self, node: Node, optype: OpType) {
239        self.entrypoint = Some(Entrypoint::new(node, optype));
240    }
241
242    /// Sets the generator for the module.
243    pub fn set_generator(&mut self, generator: impl Into<String>) {
244        self.generator = Some(generator.into());
245    }
246
247    /// Sets the extensions used by the generator in the module metadata.
248    pub fn set_used_extensions_generator(
249        &mut self,
250        used_extensions_metadata: impl IntoIterator<Item = ExtensionDesc>,
251    ) {
252        self.used_extensions_generator = Some(used_extensions_metadata.into_iter().collect());
253    }
254
255    /// Extends the extensions used by the generator in the module metadata.
256    pub fn extend_used_extensions_metadata(
257        &mut self,
258        exts: impl IntoIterator<Item = ExtensionDesc>,
259    ) {
260        extend_option_vec(&mut self.used_extensions_generator, exts);
261    }
262
263    /// Sets the resolved extensions used in the module.
264    pub fn set_used_extensions_resolved(
265        &mut self,
266        used_extensions_resolved: impl IntoIterator<Item = ExtensionDesc>,
267    ) {
268        self.used_extensions_resolved = Some(used_extensions_resolved.into_iter().collect());
269    }
270
271    /// Extends the resolved extensions used in the module.
272    pub fn extend_used_extensions_resolved(
273        &mut self,
274        exts: impl IntoIterator<Item = ExtensionDesc>,
275    ) {
276        extend_option_vec(&mut self.used_extensions_resolved, exts);
277    }
278
279    /// Sets the public symbols defined in the module.
280    pub fn set_public_symbols(&mut self, symbols: impl IntoIterator<Item = String>) {
281        self.public_symbols = Some(symbols.into_iter().collect());
282    }
283
284    /// Extends the public symbols defined in the module.
285    pub fn extend_public_symbols(&mut self, symbols: impl IntoIterator<Item = String>) {
286        extend_option_vec(&mut self.public_symbols, symbols);
287    }
288
289    /// Loads the generator from the HUGR metadata.
290    pub(crate) fn load_generator(&mut self, hugr: &impl HugrView) {
291        if let Some(val) = hugr.get_metadata(hugr.module_root(), crate::envelope::GENERATOR_KEY) {
292            self.set_generator(super::format_generator(val));
293        }
294    }
295
296    /// Loads the extensions used by the generator from the HUGR metadata.
297    pub(crate) fn load_used_extensions_generator(
298        &mut self,
299        hugr: &impl HugrView,
300    ) -> Result<(), serde_json::Error> {
301        let Some(exts) = hugr.get_metadata(hugr.module_root(), USED_EXTENSIONS_KEY) else {
302            return Ok(()); // No used extensions metadata, nothing to check
303        };
304        let used_exts: Vec<ExtensionDesc> = serde_json::from_value(exts.clone())?;
305
306        self.set_used_extensions_generator(used_exts);
307        Ok(())
308    }
309
310    /// Loads the resolved extensions used in the module from the HUGR.
311    pub(crate) fn load_used_extensions_resolved(&mut self, hugr: &impl HugrView) {
312        self.set_used_extensions_resolved(
313            hugr.extensions()
314                .iter()
315                .map(|ext| ExtensionDesc::new(&ext.name, ext.version.clone())),
316        )
317    }
318
319    /// Loads the public symbols defined in the module from the HUGR.
320    pub(crate) fn load_public_symbols(&mut self, hugr: &impl HugrView) {
321        let symbols = hugr
322            .children(hugr.module_root())
323            .filter_map(|n| match hugr.get_optype(n) {
324                OpType::FuncDecl(decl) if *decl.visibility() == crate::Visibility::Public => {
325                    Some(func_symbol(decl.func_name(), decl.signature()))
326                }
327                OpType::FuncDefn(defn) if *defn.visibility() == crate::Visibility::Public => {
328                    Some(func_symbol(defn.func_name(), defn.signature()))
329                }
330                _ => None,
331            });
332
333        self.set_public_symbols(symbols);
334    }
335
336    /// Loads the entrypoint of the module from the HUGR.
337    pub(crate) fn load_entrypoint(&mut self, hugr: &impl HugrView<Node = Node>) {
338        let node = hugr.entrypoint();
339        self.set_entrypoint(node, hugr.get_optype(node).clone());
340    }
341
342    /// Loads the number of nodes in the module from the HUGR.
343    pub(crate) fn load_num_nodes(&mut self, hugr: &impl HugrView) {
344        self.set_num_nodes(hugr.num_nodes());
345    }
346
347    /// Loads full description of the module from the HUGR.
348    pub(crate) fn load_from_hugr(&mut self, hugr: &impl HugrView<Node = Node>) {
349        self.load_num_nodes(hugr);
350        self.load_entrypoint(hugr);
351        self.load_generator(hugr);
352        self.load_used_extensions_resolved(hugr);
353        self.load_public_symbols(hugr);
354        // invalid used extensions metadata is ignored here, treated as not present
355        self.load_used_extensions_generator(hugr).ok();
356    }
357}
358
359fn func_symbol(name: &str, signature: &crate::types::PolyFuncType) -> String {
360    format!("{name}: {}", signature)
361}
362impl<H: HugrView<Node = Node>> From<&H> for ModuleDesc {
363    fn from(hugr: &H) -> Self {
364        let mut desc = ModuleDesc::default();
365        desc.load_from_hugr(hugr);
366        desc
367    }
368}
369
#[cfg(test)]
mod test {
    use super::*;
    use rstest::{fixture, rstest};
    use semver::Version;

    /// Default package description: default header, no modules, no extensions.
    #[fixture]
    fn empty_package_desc() -> PackageDesc {
        PackageDesc::default()
    }

    /// Default module description with every field unset.
    #[fixture]
    fn empty_module_desc() -> ModuleDesc {
        ModuleDesc::default()
    }

    /// A sample extension description, `test_ext` at version 1.0.0.
    #[fixture]
    fn test_extension() -> ExtensionDesc {
        ExtensionDesc::new("test_ext", Version::new(1, 0, 0))
    }

    // A freshly constructed package keeps its header and has nothing in it.
    #[rstest]
    fn test_package_desc_new() {
        let header = EnvelopeHeader::default();
        let desc = PackageDesc::new(header);
        assert_eq!(desc.header(), header);
        assert_eq!(desc.n_modules(), 0);
        assert_eq!(desc.n_packaged_extensions(), 0);
    }

    // Setting the module count is reflected by `n_modules`.
    #[rstest]
    fn test_package_desc_set_n_modules(mut empty_package_desc: PackageDesc) {
        let expected = 5;
        empty_package_desc.set_n_modules(expected);
        assert_eq!(empty_package_desc.n_modules(), expected);
    }

    // A module stored at slot 0 is the first item of the modules iterator.
    #[rstest]
    fn test_package_desc_set_module(
        mut empty_package_desc: PackageDesc,
        empty_module_desc: ModuleDesc,
    ) {
        empty_package_desc.set_module(0, empty_module_desc.clone());
        let first = empty_package_desc.modules().next().unwrap();
        assert_eq!(first.as_ref(), Some(&empty_module_desc));
    }

    // An extension stored at slot 0 is the first item of the extensions iterator.
    #[rstest]
    fn test_package_desc_set_packaged_extension(
        mut empty_package_desc: PackageDesc,
        test_extension: ExtensionDesc,
    ) {
        empty_package_desc.set_packaged_extension(0, test_extension.clone());
        let first = empty_package_desc.packaged_extensions().next();
        assert_eq!(first, Some(&test_extension));
    }

    // With a single module, the package generator is that module's generator.
    #[rstest]
    fn test_package_desc_generator(mut empty_package_desc: PackageDesc) {
        let mut module = ModuleDesc::default();
        module.set_generator("test_generator");
        empty_package_desc.set_module(0, module);

        let expected = Some("test_generator".to_string());
        assert_eq!(empty_package_desc.generator(), expected);
    }

    #[rstest]
    fn test_module_desc_set_num_nodes(mut empty_module_desc: ModuleDesc) {
        empty_module_desc.set_num_nodes(10);
        assert_eq!(empty_module_desc.num_nodes, Some(10));
    }

    // The entrypoint stores exactly the node and operation type it was given.
    #[rstest]
    fn test_module_desc_set_entrypoint(mut empty_module_desc: ModuleDesc) {
        let node = Node::from(portgraph::NodeIndex::new(0));
        let optype: OpType = crate::ops::DFG {
            signature: Default::default(),
        }
        .into();
        empty_module_desc.set_entrypoint(node, optype.clone());

        let stored = empty_module_desc.entrypoint.as_ref().unwrap();
        assert_eq!(stored.node, node);
        assert_eq!(stored.optype, optype);
    }

    // An empty `input` means "never call the setter", so the field stays `None`.
    #[rstest]
    #[case("test_generator", Some("test_generator".to_string()))]
    #[case("", None)]
    fn test_module_desc_generator(#[case] input: &str, #[case] expected: Option<String>) {
        let mut module = ModuleDesc::default();
        if !input.is_empty() {
            module.set_generator(input);
        }
        assert_eq!(module.generator, expected);
    }

    #[test]
    fn test_extension_desc_new() {
        let name = "test_extension";
        let version = Version::new(1, 0, 0);
        let desc = ExtensionDesc::new(name, version.clone());
        assert_eq!(desc.name, name);
        assert_eq!(desc.version, version);
    }

    // The extension count goes from zero to one after a single insertion.
    #[rstest]
    fn test_package_desc_n_packaged_extensions(
        mut empty_package_desc: PackageDesc,
        test_extension: ExtensionDesc,
    ) {
        assert_eq!(empty_package_desc.n_packaged_extensions(), 0);

        empty_package_desc.set_packaged_extension(0, test_extension);
        assert_eq!(empty_package_desc.n_packaged_extensions(), 1);
    }

    #[rstest]
    fn test_package_desc_modules_iterator(
        mut empty_package_desc: PackageDesc,
        empty_module_desc: ModuleDesc,
    ) {
        empty_package_desc.set_module(0, empty_module_desc.clone());

        let collected: Vec<_> = empty_package_desc.modules().collect();
        assert_eq!(collected.len(), 1);
        assert_eq!(collected[0].as_ref(), Some(&empty_module_desc));
    }

    #[rstest]
    fn test_package_desc_packaged_extensions_iterator(
        mut empty_package_desc: PackageDesc,
        test_extension: ExtensionDesc,
    ) {
        empty_package_desc.set_packaged_extension(0, test_extension.clone());

        let collected: Vec<_> = empty_package_desc.packaged_extensions().collect();
        assert_eq!(collected.len(), 1);
        assert_eq!(collected[0], &test_extension);
    }

    // Setting the generator-declared extensions stores exactly the given list.
    #[rstest]
    fn test_module_desc_set_used_extensions_generator(
        mut empty_module_desc: ModuleDesc,
        test_extension: ExtensionDesc,
    ) {
        empty_module_desc.set_used_extensions_generator(vec![test_extension.clone()]);

        let stored = empty_module_desc
            .used_extensions_generator
            .as_ref()
            .unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0], test_extension);
    }

    // Extending after setting keeps the earlier entry and adds the new one.
    #[rstest]
    fn test_module_desc_extend_used_extensions_metadata(mut empty_module_desc: ModuleDesc) {
        let first = ExtensionDesc::new("test_ext1", Version::new(1, 0, 0));
        let second = ExtensionDesc::new("test_ext2", Version::new(2, 0, 0));

        empty_module_desc.set_used_extensions_generator(vec![first.clone()]);
        empty_module_desc.extend_used_extensions_metadata(vec![second.clone()]);

        let stored = empty_module_desc
            .used_extensions_generator
            .as_ref()
            .unwrap();
        assert_eq!(stored.len(), 2);
        assert!(stored.contains(&first));
        assert!(stored.contains(&second));
    }

    #[rstest]
    fn test_module_desc_set_public_symbols(mut empty_module_desc: ModuleDesc) {
        let expected = vec!["symbol1".to_string(), "symbol2".to_string()];
        empty_module_desc.set_public_symbols(expected.clone());

        let stored = empty_module_desc.public_symbols.as_ref().unwrap();
        assert_eq!(stored.len(), 2);
        assert_eq!(stored, &expected);
    }

    // Extending after setting keeps the earlier symbol and adds the new one.
    #[rstest]
    fn test_module_desc_extend_public_symbols(mut empty_module_desc: ModuleDesc) {
        let initial = vec!["symbol1".to_string()];
        let extra = vec!["symbol2".to_string()];

        empty_module_desc.set_public_symbols(initial.clone());
        empty_module_desc.extend_public_symbols(extra.clone());

        let stored = empty_module_desc.public_symbols.as_ref().unwrap();
        assert_eq!(stored.len(), 2);
        assert!(stored.contains(&"symbol1".to_string()));
        assert!(stored.contains(&"symbol2".to_string()));
    }
}